TridentDataspec¶
Configuration¶
A TridentDataspec class encapsulates the configuration for data handling in a machine learning workflow. It manages various aspects of data processing including dataset instantiation, preprocessing, dataloading, and evaluation.
Configuration Keys
- dataset: Specifies how the dataset should be instantiated.
- dataloader: Defines the instantiation of the DataLoader.
- preprocessing (optional): Details the methods or function calls for dataset preprocessing.
- evaluation (optional): Outlines any post-processing steps and metrics for dataset evaluation.
- misc (optional): Reserved for miscellaneous settings that do not fit under other keys.
dataset¶
The dataset key instantiates a Dataset that is compatible with a PyTorch DataLoader. Any preprocessing should be defined in the corresponding preprocessing configuration of the TridentDataspec.
mnli_train:
  dataset: # required, config on how to instantiate dataset
    _target_: datasets.load_dataset
    path: glue
    name: mnli
    split: train
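The `_target_` key follows Hydra's instantiation convention: the dotted path names a callable, and the remaining keys become its keyword arguments. The following is a minimal sketch of that idea using only the standard library. `resolve` and `instantiate_sketch` are hypothetical helpers, not Hydra's or trident's implementation, and the sketch only handles plain `module.attribute` paths without nested configurations. It also covers the `_partial_` flag that appears in later examples.

```python
import importlib
from functools import partial
from typing import Any


def resolve(path: str) -> Any:
    """Import the object named by a dotted path such as 'datetime.timedelta'."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_sketch(cfg: dict) -> Any:
    """Simplified stand-in for Hydra-style instantiation (not Hydra's code)."""
    kwargs = {k: v for k, v in cfg.items() if k not in ("_target_", "_partial_")}
    target = resolve(cfg["_target_"])
    if cfg.get("_partial_", False):
        # `_partial_: true` defers the call and only binds the keyword arguments.
        return partial(target, **kwargs)
    return target(**kwargs)


# The target is called immediately with the remaining keys as keyword arguments.
delta = instantiate_sketch({"_target_": "datetime.timedelta", "days": 1, "hours": 12})
# With `_partial_`, a callable is returned that still expects further arguments.
to_binary = instantiate_sketch({"_target_": "builtins.int", "_partial_": True, "base": 2})
```

In the dataset configuration above, the same mechanism would call `datasets.load_dataset` with `path`, `name`, and `split` as keyword arguments.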
preprocessing¶
The preprocessing key in the configuration details the steps for preparing the dataset. It includes two special keys, method and apply, each holding dictionaries for specific preprocessing actions.
- method: Contains dictionaries of class methods along with their keyword arguments. These are typically methods of the dataset class.
- apply: Comprises dictionaries of user-defined functions, along with their keyword arguments, to be applied to the dataset. Be mindful that functions under apply, unlike most other keys, typically do not take _partial_: true.
The preprocessing functions take the Dataset as the first positional argument. The functions are called in the order of the configuration. Note that "method" is a convenience keyword: the same effect can be achieved by pointing the "_target_" of an "apply" function to the class method.
Example Configuration
preprocessing:
  method:
    map: # dataset.map of huggingface `datasets.arrow_dataset.Dataset`
      function:
        _target_: src.tasks.text_classification.processing.preprocess_fn
        _partial_: true
        column_names:
          text: premise
          text_pair: hypothesis
        tokenizer:
          _partial_: true
          _target_: transformers.tokenization_utils_base.PreTrainedTokenizerBase.__call__
          self:
            _target_: transformers.AutoTokenizer.from_pretrained
            pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
          padding: false
          truncation: true
          max_length: 128
    # unify output format of MNLI and XNLI
    set_format:
      columns:
        - "input_ids"
        - "attention_mask"
        - "label"
  apply:
    example_function:
      _target_: mod.package.example_function
      # ...
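To illustrate the difference between method and apply, here is a small sketch of how such a configuration could be walked in order. It is not trident's implementation: `ToyDataset`, `drop_short`, `preprocess_sketch`, and the `fn`/`kwargs` layout of the apply entries are all assumptions made for the example, and the functions are passed in directly instead of being instantiated from `_target_`.

```python
from typing import Any, Callable


class ToyDataset:
    """Hypothetical stand-in for a dataset class with chainable methods."""

    def __init__(self, rows: list[dict]) -> None:
        self.rows = rows

    def map(self, function: Callable[[dict], dict]) -> "ToyDataset":
        return ToyDataset([function(dict(row)) for row in self.rows])

    def select_columns(self, columns: list[str]) -> "ToyDataset":
        return ToyDataset([{c: row[c] for c in columns} for row in self.rows])


def drop_short(dataset: ToyDataset, min_len: int) -> ToyDataset:
    """User-defined function: receives the dataset as first positional argument."""
    return ToyDataset([r for r in dataset.rows if len(r["text"]) >= min_len])


def preprocess_sketch(dataset: Any, cfg: dict) -> Any:
    """Apply preprocessing steps in configuration order (illustrative only)."""
    for key, steps in cfg.items():
        for name, step in steps.items():
            if key == "method":
                # Look the method up on the dataset and call it with its kwargs.
                dataset = getattr(dataset, name)(**step)
            elif key == "apply":
                # Call the function with the dataset first, then its kwargs.
                dataset = step["fn"](dataset, **step["kwargs"])
    return dataset


cfg = {
    "method": {
        "map": {"function": lambda row: {**row, "text": row["text"].strip()}},
        "select_columns": {"columns": ["text"]},
    },
    "apply": {
        "drop_short": {"fn": drop_short, "kwargs": {"min_len": 3}},
    },
}
data = ToyDataset([{"text": " hello ", "id": 0}, {"text": "hi", "id": 1}])
result = preprocess_sketch(data, cfg)
```

Each step receives the dataset returned by the previous one, which is why the order of entries in the configuration matters.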
dataloader¶
The DataLoader configuration (configs/dataspec/dataloader/default.yaml) is preset with reasonable defaults, accommodating typical use cases.
Example Configuration
_target_: torch.utils.data.dataloader.DataLoader
collate_fn:
  _target_: transformers.data.data_collator.DataCollatorWithPadding
  tokenizer:
    _target_: transformers.AutoTokenizer.from_pretrained
    pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
  max_length: ???
batch_size: 32
pin_memory: true
shuffle: false
num_workers: 4
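The preprocessing example tokenizes with padding: false, so padding is left to the collate function, which pads each batch at load time. The sketch below shows the idea of such dynamic padding in plain Python. `pad_collate` is a hypothetical helper and not the `DataCollatorWithPadding` implementation, which additionally relies on the tokenizer's padding settings and returns tensors. Renaming the label column to labels mirrors the key that the step_outputs example later collects from the batch.

```python
def pad_collate(features: list[dict], pad_token_id: int = 0) -> dict:
    """Pad every sequence to the longest one in this batch (illustrative)."""
    longest = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = longest - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # Padded positions are masked out with zeros.
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
        batch["labels"].append(f["label"])
    return batch


batch = pad_collate([
    {"input_ids": [101, 7, 102], "attention_mask": [1, 1, 1], "label": 0},
    {"input_ids": [101, 102], "attention_mask": [1, 1], "label": 2},
])
```

Padding per batch instead of to a global maximum keeps short batches small, which is the usual motivation for deferring padding to the collate function.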
evaluation¶
The logic of evaluation is defined in ./configs/dataspec/evaluation/text_classification.yaml. It is common to define evaluation per type of task.
The evaluation configuration is divided into the fields prepare, step_outputs, and metrics.
See also
trident.utils.types.EvaluationDict
prepare¶
prepare defines functions called on the batch, the model outputs, or the collected step_outputs.
The TridentModule passes the keyword arguments listed below to these functions to facilitate evaluation. Since the TridentModule extends the LightningModule, useful attributes like trainer and trainer.datamodule are available at runtime.
Example Configuration
prepare:
  # takes (trident_module: TridentModule, batch: dict, split: Split, dataset_name: str) -> dict
  batch: null
  # takes (trident_module: TridentModule, outputs: dict, batch: dict, split: Split, dataset_name: str) -> dict
  outputs:
    _partial_: true
    _target_: src.tasks.text_classification.evaluation.get_preds
  # takes (trident_module: TridentModule, step_outputs: dict, split: Split, dataset_name: str) -> dict
  step_outputs: null
where get_preds is defined as follows and merely adds the predictions (the argmax of the logits) to outputs under the key preds:
def get_preds(outputs: dict, *args, **kwargs) -> dict:
    outputs["preds"] = outputs["logits"].argmax(dim=-1)
    return outputs
See also
trident.utils.enums.Split, trident.utils.types.PrepareDict
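A prepare function for the batch can follow the same pattern as get_preds: accept the dictionary it modifies and absorb the remaining arguments. The sketch below assumes that the arguments are passed by keyword, as the signature of get_preds suggests. `add_num_tokens` is a hypothetical function, and the module, split, and dataset name are made-up placeholders (the split would be a `Split` enum member in practice).

```python
from types import SimpleNamespace


def add_num_tokens(batch: dict, *args, **kwargs) -> dict:
    """Hypothetical `prepare.batch` function: adds a per-example token count."""
    batch["num_tokens"] = [sum(mask) for mask in batch["attention_mask"]]
    return batch


# Stand-ins for the objects trident would pass; all values here are made up.
fake_module = SimpleNamespace(trainer=None)
batch = {"attention_mask": [[1, 1, 1], [1, 1, 0]], "labels": [0, 2]}
prepared = add_num_tokens(
    trident_module=fake_module, batch=batch, split="val", dataset_name="mnli_val"
)
```

Accepting `*args` and `**kwargs` lets the function ignore arguments it does not need, such as the module or the split.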
step_outputs¶
step_outputs defines which keys are collected from a batch or outputs dictionary, per step, into the flattened outputs dict per evaluation dataloader. The flattened dictionary then holds the corresponding key-value pairs as input to the prepare_step_outputs function, which ultimately serves as input to the metrics computed at the end of an evaluation loop.
Note
trident ensures that after each evaluation loop, lists of np.ndarrays or torch.Tensors are correctly stacked into a single array with appropriate dimensions.
Example Configuration
# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  # can be a str
  batch: labels
  # or a list[str]
  outputs:
    - "preds"
    - "logits"
See also
trident.utils.flatten_dict()
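The collection described above can be sketched as follows. This is an illustration under simplifying assumptions, not trident's code: `collect` and `flatten` are hypothetical helpers, and plain lists stand in for the arrays or tensors that trident would stack.

```python
from collections import defaultdict
from itertools import chain


def collect(step_cfg: dict, batch: dict, outputs: dict, store: dict) -> None:
    """Append the configured keys of one evaluation step to `store`."""
    sources = {"batch": batch, "outputs": outputs}
    for source_name, keys in step_cfg.items():
        # The configuration accepts either a single str or a list[str].
        for key in [keys] if isinstance(keys, str) else keys:
            store[key].append(sources[source_name][key])


def flatten(store: dict) -> dict:
    """Concatenate the per-step lists; stands in for stacking arrays/tensors."""
    return {key: list(chain.from_iterable(values)) for key, values in store.items()}


step_cfg = {"batch": "labels", "outputs": ["preds"]}
store = defaultdict(list)
collect(step_cfg, {"labels": [0, 2]}, {"preds": [0, 1], "logits": None}, store)
collect(step_cfg, {"labels": [1]}, {"preds": [1], "logits": None}, store)
step_outputs = flatten(store)
```

Keys that are not listed in the configuration, such as `logits` here, are not carried over to the end of the evaluation loop.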
metrics¶
metrics denotes a dictionary for all evaluated metrics. For instance, a metric such as acc may contain:
- metric: how to instantiate the metric; typically a partial function; must return a Callable.
- compute_on: Either eval_step or epoch_end, with the latter being the default.
- kwargs: A custom syntax to fetch kwargs of metric from one of the following: [trident_module, outputs, batch, cfg].
  - outputs refers to the model outputs when compute_on is set to eval_step and to step_outputs when compute_on is set to epoch_end.
  - In the NLI example: The keyword preds for torchmetrics.functional.accuracy is sourced from outputs["preds"]. The keyword target for torchmetrics.functional.accuracy is sourced from outputs["labels"].
Example Configuration
metrics:
  # name of the metric, used e.g. for logging
  acc:
    # instructions to instantiate the metric, preferably a torchmetrics.Metric
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
    # either "eval_step" or "epoch_end", defaults to "epoch_end"
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"
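The kwargs syntax maps each keyword of the metric to a dotted path into the available objects. A minimal sketch of how such paths could be resolved is given below. `fetch` and `compute_metric` are hypothetical helpers, the resolution logic is an assumption and not trident's implementation, and `accuracy` is a plain-Python stand-in for `torchmetrics.functional.accuracy`.

```python
from functools import reduce


def accuracy(preds: list, target: list) -> float:
    """Plain-Python stand-in for an accuracy metric."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def fetch(context: dict, path: str):
    """Resolve a dotted path such as 'outputs.preds' against the context."""
    def step(obj, part):
        # Dictionaries are indexed by key; other objects by attribute.
        return obj[part] if isinstance(obj, dict) else getattr(obj, part)
    return reduce(step, path.split("."), context)


def compute_metric(metric, kwargs_cfg: dict, context: dict):
    """Look up every configured kwarg, then call the metric with them."""
    return metric(**{name: fetch(context, path) for name, path in kwargs_cfg.items()})


context = {"outputs": {"preds": [0, 1, 1, 2], "labels": [0, 2, 1, 2]}}
acc = compute_metric(
    accuracy, {"preds": "outputs.preds", "target": "outputs.labels"}, context
)
```

With compute_on set to epoch_end, the `outputs` entry of the context would hold the collected step_outputs instead of the outputs of a single step.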
Note
It is also possible to use metrics for actions other than logging evaluation results. The metric function should then return None.
In the case below, we construct a "metric" that stores predictions instead. A metric is not logged to wandb if the function returns None.
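from pathlib import Path

import pandas as pd
import torch


def store_predictions(
    # TridentModule is quoted so the snippet does not depend on its import path.
    trident_module: "TridentModule",
    preds: torch.Tensor,
    dataset_name: str,
    directory: str,
    *args,
    **kwargs,
) -> None:
    # The trainer is available because TridentModule extends LightningModule.
    trainer = trident_module.trainer
    epoch = trainer.current_epoch
    # Move predictions to the CPU and write one row per example.
    p = pd.DataFrame(preds.cpu())
    p.columns = ["prediction"]
    path = Path(directory).joinpath(f"dataset={dataset_name}_epoch={epoch}.csv")
    p.to_csv(str(path), index_label="ids")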
And the corresponding metric configuration:
metrics:
  store_predictions:
    metric:
      _partial_: true
      _target_: $PATH_TO_FUNCTION.store_predictions
      directory: ${hydra:runtime.output_dir}
    compute_on: "epoch_end"
    kwargs:
      trident_module: "trident_module"
      preds: "outputs.preds"
      dataset_name: "dataset_name"
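The rule that a metric returning None is not logged can be pictured with the short sketch below. `log_results` and `fake_store_predictions` are hypothetical and only mimic the behaviour described above; they are not part of trident, and the accuracy value is made up.

```python
def log_results(results: dict) -> dict:
    """Keep only metrics that returned a value; `None` results are skipped."""
    return {name: value for name, value in results.items() if value is not None}


def fake_store_predictions(preds: list, sink: list) -> None:
    """Side-effect-only 'metric': records predictions and returns None."""
    sink.extend(preds)


saved: list = []
results = {
    "acc": 0.75,  # made-up value for illustration
    "store_predictions": fake_store_predictions([0, 1, 1], saved),
}
logged = log_results(results)
```

The side effect still happens when the function runs; only the logging step is skipped for its result.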
API¶
Properties¶
cfg¶
The cfg: omegaconf.DictConfig holds the dataspec configuration.
dataset¶
The dataset as declared in cfg.dataset and preprocessed by cfg.preprocessing.
evaluation¶
The evaluation as declared in cfg.evaluation.
Methods¶
get_dataloader¶
- TridentDataspec.get_dataloader()
Creates a DataLoader for the dataset.
- Parameters:
signature_columns – Columns to be used in the dataloader. Defaults to None. If passed, removes unused columns if configured in misc.
- Returns:
The DataLoader configured as per the specified settings.
- Return type:
DataLoader
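As a rough illustration of what removing unused columns means, the sketch below keeps only the columns named in signature_columns. Deriving the column names from the signature of a model's forward function is an assumption made for this example, and `forward` and `remove_unused_columns` are hypothetical. How trident obtains signature_columns and where the option lives in misc is not shown here.

```python
import inspect


def forward(input_ids, attention_mask, labels=None):
    """Hypothetical model forward function."""
    return None


def remove_unused_columns(rows: list[dict], signature_columns: list[str]) -> list[dict]:
    """Keep only the columns named in `signature_columns`."""
    return [{k: v for k, v in row.items() if k in signature_columns} for row in rows]


# Assumption: the accepted columns are the parameter names of `forward`.
signature_columns = list(inspect.signature(forward).parameters)
rows = [{"input_ids": [101, 102], "attention_mask": [1, 1], "labels": 0, "premise": "..."}]
filtered = remove_unused_columns(rows, signature_columns)
```

Dropping columns such as raw text before batching avoids handing the collate function values it cannot pad or convert.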