TridentDataspec
Configuration
The TridentDataspec class encapsulates the configuration for data handling in a machine learning workflow. It manages dataset instantiation, preprocessing, dataloading, and evaluation.
Configuration Keys
- dataset: Specifies how the dataset should be instantiated.
- dataloader: Defines the instantiation of the DataLoader.
- preprocessing (optional): Details the methods or function calls for dataset preprocessing.
- evaluation (optional): Outlines any post-processing steps and metrics for dataset evaluation.
- misc (optional): Reserved for miscellaneous settings that do not fit under other keys.
dataset
The dataset key instantiates a Dataset that is compatible with a PyTorch DataLoader. Any preprocessing should be defined in the corresponding preprocessing configuration of the TridentDataspec.
mnli_train:
  dataset: # required, config on how to instantiate dataset
    _target_: datasets.load_dataset
    path: glue
    name: mnli
    split: train
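Each configuration node with a `_target_` key is turned into a Python call by Hydra. To make that mapping concrete, here is a deliberately simplified stand-in for `hydra.utils.instantiate`. The helpers `locate` and `instantiate` are hypothetical and omit Hydra features such as interpolation, `_args_`, and `_recursive_`; the sketch only illustrates the semantics of `_target_` and `_partial_` and is not how trident or Hydra implement them.

```python
import importlib
from functools import partial
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path such as "datasets.load_dataset" to a Python object."""
    parts = path.split(".")
    # import the longest importable module prefix, then walk the remaining attributes
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot locate {path!r}")


def instantiate(cfg: Any) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate."""
    if isinstance(cfg, list):
        return [instantiate(item) for item in cfg]
    if not isinstance(cfg, dict):
        return cfg  # plain values pass through unchanged
    # nested nodes are instantiated first (e.g. a tokenizer's `self` node)
    kwargs = {k: instantiate(v) for k, v in cfg.items() if k not in ("_target_", "_partial_")}
    if "_target_" not in cfg:
        return kwargs  # an ordinary mapping of options
    target = locate(cfg["_target_"])
    if cfg.get("_partial_", False):
        return partial(target, **kwargs)  # call deferred; remaining args supplied later
    return target(**kwargs)


# The mnli_train dataset node corresponds to the call
#   datasets.load_dataset(path="glue", name="mnli", split="train")
# A standard-library target keeps this example offline:
frac = instantiate({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})
print(frac)  # 3/4
```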
preprocessing
The preprocessing key in the configuration details the steps for preparing the dataset. It includes two special keys, method and apply, each holding dictionaries for specific preprocessing actions.
- method: Contains dictionaries of class methods, typically methods of the dataset class, along with their keyword arguments.
- apply: Contains dictionaries of user-defined functions, along with their keyword arguments, to be applied to the dataset. Be mindful that, unlike most other keys, functions under apply typically do not take _partial_: true.
The preprocessing functions take the Dataset as their first positional argument and are called in the order in which they appear in the configuration. Note that method is a convenience key: the same effect can be achieved by pointing the _target_ of an apply function to the corresponding method of the dataset class.
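The ordering rules above can be sketched in plain Python. `run_preprocessing`, `ToyDataset`, and `add_length` are hypothetical. The sketch assumes that `method` and `apply` entries are already instantiated, that they run in configuration order, and that a step returning None (like `set_format` in Hugging Face `datasets`, which works in place) leaves the dataset as it is. trident's actual handling may differ.

```python
from typing import Any, Callable, Dict


class ToyDataset:
    """Hypothetical stand-in for a Hugging Face datasets.Dataset."""

    def __init__(self, rows: list) -> None:
        self.rows = rows
        self.columns = None

    def map(self, function: Callable[[dict], dict]) -> "ToyDataset":
        # like datasets.Dataset.map, returns a new dataset
        return ToyDataset([function(row) for row in self.rows])

    def set_format(self, columns: list) -> None:
        # like datasets.Dataset.set_format, works in place and returns None
        self.columns = columns


def run_preprocessing(dataset: Any, preprocessing: Dict[str, Dict[str, Any]]) -> Any:
    """Hypothetical sketch: run `method` and `apply` entries in configuration order."""
    for kind, steps in preprocessing.items():
        for name, spec in steps.items():
            if kind == "method":
                # look up the dataset's own method by name and pass its kwargs
                result = getattr(dataset, name)(**spec)
            elif kind == "apply":
                fn, kwargs = spec  # (callable, kwargs) pair
                # the user function receives the dataset as first positional argument
                result = fn(dataset, **kwargs)
            else:
                raise ValueError(f"unknown preprocessing key: {kind!r}")
            # assumption: a None result means the step modified the dataset in place
            dataset = dataset if result is None else result
    return dataset


def add_length(dataset: ToyDataset, key: str) -> None:
    """Hypothetical user-defined `apply` function that edits rows in place."""
    for row in dataset.rows:
        row["len"] = len(row[key])


preprocessing = {
    "method": {
        "map": {"function": lambda row: {"text": row["text"].upper()}},
        "set_format": {"columns": ["text"]},
    },
    "apply": {"add_length": (add_length, {"key": "text"})},
}
ds = run_preprocessing(ToyDataset([{"text": "ab"}, {"text": "c"}]), preprocessing)
print(ds.rows)     # [{'text': 'AB', 'len': 2}, {'text': 'C', 'len': 1}]
print(ds.columns)  # ['text']
```

Passing `(ToyDataset.map, {"function": ...})` under `apply` produces the same result as the `map` entry under `method`, which mirrors the point that `method` is only a convenience key.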
Example Configuration
preprocessing:
  method:
    map: # dataset.map of huggingface `datasets.arrow_dataset.Dataset`
      function:
        _target_: src.tasks.text_classification.processing.preprocess_fn
        _partial_: true
        column_names:
          text: premise
          text_pair: hypothesis
        tokenizer:
          _partial_: true
          _target_: transformers.tokenization_utils_base.PreTrainedTokenizerBase.__call__
          self:
            _target_: transformers.AutoTokenizer.from_pretrained
            pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
          padding: false
          truncation: true
          max_length: 128
    # unify output format of MNLI and XNLI
    set_format:
      columns:
        - "input_ids"
        - "attention_mask"
        - "label"
  apply:
    example_function:
      _target_: mod.package.example_function
      # ...
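The tokenizer entry above binds the tokenizer instance to the unbound `PreTrainedTokenizerBase.__call__` through the `self` keyword, and `_partial_: true` turns the node into a `functools.partial`. The toy class below (hypothetical, standing in for a Hugging Face tokenizer) shows one consequence of this pattern: because `self` is bound by keyword, the resulting callable must be invoked with keyword arguments, since a positional argument would also try to bind to `self`. Presumably this is why `column_names` maps to the keyword names `text` and `text_pair`, though that reading of `preprocess_fn` is an assumption.

```python
from functools import partial


class ToyTokenizer:
    """Hypothetical stand-in for a Hugging Face tokenizer: whitespace splitting."""

    def __call__(self, text, text_pair=None, truncation=False, max_length=None):
        tokens = text.split() + (text_pair.split() if text_pair else [])
        if truncation and max_length is not None:
            tokens = tokens[:max_length]
        return {"input_ids": tokens}


# Mirrors the config: _target_ is the unbound __call__, `self` is the instantiated
# tokenizer, and the remaining keys become pre-bound keyword arguments.
tokenize = partial(ToyTokenizer.__call__, self=ToyTokenizer(), truncation=True, max_length=3)

# `self` is bound by keyword, so the remaining arguments must be keywords as well.
print(tokenize(text="a b", text_pair="c d"))  # {'input_ids': ['a', 'b', 'c']}

# A positional call would bind "a b" to `self` as well and fail:
# tokenize("a b")  -> TypeError: ... got multiple values for argument 'self'
```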
dataloader
The DataLoader configuration (configs/dataspec/dataloader/default.yaml) comes preset with reasonable defaults that cover typical use cases.
Example Configuration
_target_: torch.utils.data.dataloader.DataLoader
collate_fn:
  _target_: transformers.data.data_collator.DataCollatorWithPadding
  tokenizer:
    _target_: transformers.AutoTokenizer.from_pretrained
    pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
  max_length: ???
batch_size: 32
pin_memory: true
shuffle: false
num_workers: 4
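In OmegaConf, `???` marks a mandatory value, so `max_length` has to be supplied before the DataLoader configuration can be instantiated. The sketch below mimics, in plain Python, how a dataspec could override the preset defaults and how unfilled `???` values can be detected. `deep_merge` and `missing_keys` are hypothetical simplifications of `OmegaConf.merge` and OmegaConf's missing-value checks; interpolations such as `${...}` are ignored, and the `_target_` nodes are left out for brevity.

```python
from typing import Any, Dict, List

MISSING = "???"  # OmegaConf's marker for a mandatory value that has not been set


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively overlay `override` on `base` without mutating either."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def missing_keys(cfg: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return the dotted paths of all values still set to '???'."""
    found: List[str] = []
    for key, value in cfg.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            found += missing_keys(value, prefix=f"{path}.")
        elif value == MISSING:
            found.append(path)
    return found


defaults = {
    "collate_fn": {"max_length": MISSING},
    "batch_size": 32,
    "pin_memory": True,
    "shuffle": False,
    "num_workers": 4,
}
# a hypothetical training dataspec that shuffles and fills in the mandatory max_length
train = deep_merge(defaults, {"shuffle": True, "collate_fn": {"max_length": 128}})
print(missing_keys(defaults))  # ['collate_fn.max_length']
print(missing_keys(train))     # []
```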
evaluation
For text classification, the evaluation logic is defined in ./configs/dataspec/evaluation/text_classification.yaml; evaluation is commonly defined per task type.
The evaluation configuration is divided into three fields: prepare, step_outputs, and metrics.
See also
trident.utils.types.EvaluationDict
prepare
prepare defines functions called on the batch, the model outputs, or the collected step_outputs.
The TridentModule passes the keyword arguments listed in the comments below to these functions. Since the TridentModule extends the LightningModule, useful attributes such as trainer and trainer.datamodule are available at runtime.
Example Configuration
prepare:
  # takes (trident_module: TridentModule, batch: dict, split: Split, dataset_name: str) -> dict
  batch: null
  # takes (trident_module: TridentModule, outputs: dict, batch: dict, split: Split, dataset_name: str) -> dict
  outputs:
    _partial_: true
    _target_: src.tasks.text_classification.evaluation.get_preds
  # takes (trident_module: TridentModule, step_outputs: dict, split: Split, dataset_name: str) -> dict
  step_outputs: null
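where get_preds is defined as follows and merely adds the predicted class (the argmax over the logits) to outputs under the key preds:
def get_preds(outputs: dict, *args, **kwargs) -> dict:
    # predicted class index = position of the largest logit along the last dimension
    outputs["preds"] = outputs["logits"].argmax(dim=-1)
    return outputs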
See also
trident.utils.enums.Split, trident.utils.types.PrepareDict
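The prepare hooks can be pictured as a small dispatcher. The sketch below is hypothetical: `run_prepare_outputs` and `get_preds_list` (a list-based stand-in for `get_preds`) are not part of trident, and `split` is a plain string rather than `trident.utils.enums.Split`. It assumes that all arguments are passed by keyword. Under that assumption `get_preds` can declare only `outputs` and absorb `trident_module`, `batch`, `split`, and `dataset_name` through `**kwargs`; if the module were passed positionally, it would land in `outputs` instead.

```python
from typing import Any, Callable, Optional


def get_preds_list(outputs: dict, *args, **kwargs) -> dict:
    """List-based stand-in for get_preds: index of the largest logit per example."""
    outputs["preds"] = [max(range(len(row)), key=row.__getitem__) for row in outputs["logits"]]
    return outputs


def run_prepare_outputs(
    fn: Optional[Callable[..., dict]],
    trident_module: Any,
    outputs: dict,
    batch: dict,
    split: str,
    dataset_name: str,
) -> dict:
    """Hypothetical dispatcher for prepare.outputs; a null entry is a no-op."""
    if fn is None:
        return outputs
    # assumption: everything is passed by keyword, so hooks may name only what they use
    return fn(
        trident_module=trident_module,
        outputs=outputs,
        batch=batch,
        split=split,
        dataset_name=dataset_name,
    )


out = run_prepare_outputs(
    get_preds_list,
    trident_module=None,
    outputs={"logits": [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]},
    batch={},
    split="val",
    dataset_name="mnli_val",
)
print(out["preds"])  # [1, 0]
```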
step_outputs
step_outputs defines which keys are collected, per step, from the batch or outputs dictionary into a flattened outputs dict for each evaluation dataloader. This flattened dictionary holds the corresponding key-value pairs and is passed to the prepare_step_outputs function, whose result ultimately serves as input to the metrics computed at the end of an evaluation loop.
Note
trident ensures that, after each evaluation loop, lists of np.ndarrays or torch.Tensors are correctly stacked into a single array with the appropriate dimensions.
Example Configuration
# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  # can be a str
  batch: labels
  # or a list[str]
  outputs:
    - "preds"
    - "logits"
See also
trident.utils.flatten_dict()
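A plain-Python sketch of the collection step described above, using lists in place of tensors: `collect_step` gathers the configured keys from `batch` and `outputs` into one flat dict per dataloader, and `finalize` concatenates the per-step values, analogous to the stacking trident performs. Both helpers are hypothetical. Because batch and outputs keys share a single namespace, `labels` collected from the batch is what the metric configuration later reads as `outputs.labels`.

```python
from typing import Dict, List, Optional, Union

Keys = Optional[Union[str, List[str]]]


def collect_step(store: Dict[str, list], batch: dict, outputs: dict, cfg: Dict[str, Keys]) -> None:
    """Hypothetical sketch: append this step's configured keys to one flat dict."""
    for source_name, source in (("batch", batch), ("outputs", outputs)):
        keys = cfg.get(source_name) or []
        # a single key may be given as a str, several as a list[str]
        for key in [keys] if isinstance(keys, str) else keys:
            # batch and outputs share one flat namespace, so keep key names distinct
            store.setdefault(key, []).append(source[key])


def finalize(store: Dict[str, list]) -> Dict[str, list]:
    """Concatenate per-step lists (trident stacks np.ndarrays/torch.Tensors instead)."""
    return {key: [item for step in steps for item in step] for key, steps in store.items()}


cfg = {"batch": "labels", "outputs": ["preds", "logits"]}
store: Dict[str, list] = {}
collect_step(store, {"labels": [0, 2]}, {"preds": [0, 1], "logits": [[0.9, 0.1, 0.0], [0.2, 0.7, 0.1]]}, cfg)
collect_step(store, {"labels": [1]}, {"preds": [1], "logits": [[0.1, 0.8, 0.1]]}, cfg)
step_outputs = finalize(store)
print(step_outputs["labels"], step_outputs["preds"])  # [0, 2, 1] [0, 1, 1]
```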
metrics
metrics denotes a dictionary of all evaluated metrics. For instance, a metric such as acc may contain:
- metric: how to instantiate the metric; typically a partial function; must return a Callable.
- compute_on: either eval_step or epoch_end, with the latter being the default.
- kwargs: a custom syntax to fetch the kwargs of metric from one of the following: [trident_module, outputs, batch, cfg].
  - outputs refers to the model outputs when compute_on is set to eval_step, and to step_outputs when compute_on is set to epoch_end.
  - In the NLI example:
    - The keyword preds for torchmetrics.functional.accuracy is sourced from outputs["preds"].
    - The keyword target for torchmetrics.functional.accuracy is sourced from outputs["labels"].
Example Configuration
metrics:
  # name of the metric, used e.g. for logging
  acc:
    # instructions to instantiate the metric, preferably a torchmetrics.Metric
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
      # torchmetrics >= 0.11 additionally requires task (e.g. "multiclass") and num_classes
    # either "eval_step" or "epoch_end", defaults to "epoch_end"
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"
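The dotted kwargs syntax can be illustrated with a hypothetical resolver: each value such as `"outputs.preds"` is split on dots, the first part selects one of the available sources, and the remaining parts index into it. The sketch uses a list-based `accuracy` in place of `torchmetrics.functional.accuracy` and assumes `compute_on: "epoch_end"`, so `outputs` holds the collected step_outputs. `resolve`, `compute_metric`, and `accuracy` are illustrative helpers, not trident's implementation.

```python
from functools import partial
from typing import Any, Callable, Dict


def accuracy(preds: list, target: list) -> float:
    """List-based stand-in for torchmetrics.functional.accuracy."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def resolve(spec: str, sources: Dict[str, Any]) -> Any:
    """Hypothetical resolver: 'outputs.preds' -> sources['outputs']['preds']."""
    head, *rest = spec.split(".")
    value = sources[head]
    for part in rest:
        # dicts are indexed by key; other objects (e.g. a module) by attribute
        value = value[part] if isinstance(value, dict) else getattr(value, part)
    return value


def compute_metric(
    metric: Callable[..., Any], kwargs_cfg: Dict[str, str], sources: Dict[str, Any]
) -> Any:
    """Resolve every configured kwarg, then call the (partial) metric."""
    kwargs = {name: resolve(spec, sources) for name, spec in kwargs_cfg.items()}
    return metric(**kwargs)


# with compute_on: "epoch_end", `outputs` holds the collected step_outputs
sources = {
    "outputs": {"preds": [0, 1, 1, 2], "labels": [0, 1, 2, 2]},
    "dataset_name": "mnli_val",
}
acc = compute_metric(
    partial(accuracy),  # mirrors `_partial_: true` in the metric config
    {"preds": "outputs.preds", "target": "outputs.labels"},
    sources,
)
print(acc)  # 0.75
```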
Note
Metrics can also be used for actions other than logging evaluation results; in that case, the metric function should return None.
The example below constructs a “metric” that stores predictions to a CSV file instead. A metric whose function returns None is not logged to wandb.
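from __future__ import annotations  # keeps the TridentModule hint unevaluated

from pathlib import Path

import pandas as pd
import torch


def store_predictions(
    trident_module: TridentModule,  # annotation only; no import needed at runtime
    preds: torch.Tensor,
    dataset_name: str,
    directory: str,
    *args,
    **kwargs,
) -> None:
    # current epoch from the Lightning trainer attached to the module
    epoch = trident_module.trainer.current_epoch
    # one row per example; the DataFrame index is written as the "ids" column
    p = pd.DataFrame({"prediction": preds.cpu().numpy()})
    path = Path(directory).joinpath(f"dataset={dataset_name}_epoch={epoch}.csv")
    p.to_csv(path, index_label="ids")
    # returning None means this "metric" is not logged
    return None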
The corresponding metric configuration:
metrics:
  store_predictions:
    metric:
      _partial_: true
      _target_: $PATH_TO_FUNCTION.store_predictions
      directory: ${hydra:runtime.output_dir}
    compute_on: "epoch_end"
    kwargs:
      trident_module: "trident_module"
      preds: "outputs.preds"
      dataset_name: "dataset_name"
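The rule that a None result is not logged can be sketched as a filter in front of the logger. `log_metrics` is hypothetical, and the plain dict below stands in for whatever logging backend (for example wandb) is configured.

```python
from typing import Any, Callable, Dict


def log_metrics(results: Dict[str, Any], log: Callable[[str, Any], None]) -> None:
    """Hypothetical sketch: forward only metrics that returned a value."""
    for name, value in results.items():
        if value is None:  # e.g. store_predictions, which only writes a CSV file
            continue
        log(name, value)


logged: Dict[str, Any] = {}
log_metrics({"acc": 0.75, "store_predictions": None}, logged.__setitem__)
print(logged)  # {'acc': 0.75}
```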
API
Properties
cfg
The cfg: omegaconf.DictConfig holds the dataspec configuration.
dataset
The dataset as declared in cfg.dataset and preprocessed by cfg.preprocessing.
evaluation
The evaluation as declared in cfg.evaluation.
Methods
get_dataloader
- TridentDataspec.get_dataloader()
Creates a DataLoader for the dataset.
- Parameters:
signature_columns – Columns to be used in the dataloader. If passed, unused columns are removed when configured in misc. Defaults to None.
- Returns:
The DataLoader configured as per the specified settings.
- Return type:
DataLoader
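How signature_columns is derived is not spelled out here. As an assumption-labeled sketch, the snippet below follows the approach of Hugging Face's `Trainer`, which reads the parameter names of the model's `forward` via `inspect.signature` and drops every other column. `remove_unused_columns` and `forward` are hypothetical, the misc switch is omitted, and the dataset is modeled as a dict of column lists.

```python
import inspect
from typing import Dict, List, Optional


def remove_unused_columns(
    columns: Dict[str, list], signature_columns: Optional[List[str]]
) -> Dict[str, list]:
    """Hypothetical sketch: keep only columns named in signature_columns."""
    if signature_columns is None:
        return columns  # nothing passed: leave the dataset untouched
    return {name: values for name, values in columns.items() if name in signature_columns}


def forward(input_ids, attention_mask=None, labels=None):
    """Hypothetical model forward; only its parameter names matter here."""


signature_columns = list(inspect.signature(forward).parameters)
data = {
    "input_ids": [[101, 102]],
    "attention_mask": [[1, 1]],
    "labels": [0],
    "premise": ["A raw text column the model does not accept."],
}
print(sorted(remove_unused_columns(data, signature_columns)))
# ['attention_mask', 'input_ids', 'labels']
```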