trident in 20 minutes¶
This walkthrough first introduces common concepts of hydra and then builds an exemplary text-classification pipeline for sequence-pair classification (NLI). The example NLI project is embedded in the repository.
hydra primer¶
It is important to have basic familiarity with hydra, which excels at bottom-up, hierarchical YAML configuration.
Two key features of hydra for trident are
The defaults-list for hierarchical configuration composition.
Package directives to cleanly combine configurations.
Below is a brief primer of hydra.
Context¶
For trident, we will use hydra to declare our
Trainer: declares how training, validation, and testing is run
TridentModule: declares your “model”
TridentDatamodule: declares your train, val, and test splits
And other components like logging and checkpointing
Hierarchical Configuration¶
hydra is a Python framework for composing complex pipelines from YAML configurations. In our simplified example, the directory tree of your trident configuration may look like the file tree below.
In what follows, we will focus on the hierarchy of the configuration below in the case of dataspec. A dataspec defines the dataset, the associated pipelines for preprocessing and evaluation, and the dataloader for the preprocessed dataset.
configs
├── config.yaml
├── dataspec # defines dataset, preprocessing, evaluation, and dataloader
│ ├── dataloader
│ │ ├── default.yaml
│ │ └── train.yaml
│ ├── evaluation
│ │ └── text_classification.yaml
│ ├── preprocessing
│ │ └── shots.yaml
│ ├── default.yaml # general
│ ├── text_classification.yaml # task-specific: inherits general
│ └── nli.yaml # more dataset-specific: inherits task-specific
├── dataspecs # will group a dataspec for a particular benchmark
└── module
├── default.yaml
└── my_module.yaml
The configs/dataspec/default.yaml defines the defaults.
The null default of dataset reserves the key for a future override
dataloader: default means that the ./configs/dataspec/dataloader/default.yaml config will be sourced into the dataloader key of the default dataspec
We define the default _target_ of a dataset, which typically uses the 🤗 datasets library
defaults:
  # a null default reserves the key for a future override
  - dataset: null
  # the default dataloader is sourced into the `dataloader` key of the dataspec,
  # i.e. ./configs/dataspec/dataloader/default.yaml will be sourced into the `dataloader` key of this config
  - dataloader: default
  # _self_ allows you to control the resolution order of the config itself
  # _self_ is not required and is appended to the end of the defaults list by default
  - _self_

dataset:
  # most datasets use the Huggingface `datasets` library
  _target_: datasets.load.load_dataset
The configs/dataspec/text_classification.yaml then extends the default.yaml
defaults:
  - default
  - evaluation: text_classification
# ... more configuration
The above sources the configuration of configs/dataspec/default.yaml and merges the task-specific ./configs/dataspec/evaluation/text_classification.yaml into the evaluation key of text_classification.yaml.
Lastly, ./configs/dataspec/nli.yaml defines even more specific configuration.
defaults:
  - text_classification
# ... nli-specific configuration
Notes:
hydra follows an additive configuration paradigm: design your configuration to incrementally add what's required! Unsetting or removing options is often very unwieldy
The configuration relies on relative paths within the configuration folder structure
In the defaults-list, the last entry wins, as dictionaries are merged sequentially (see the sketch below)
The config's own keys typically resolve last (this can be controlled via _self_, see the official documentation)
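A minimal OmegaConf sketch of these merge semantics (hydra builds on OmegaConf and composes the defaults-list by merging sequentially):

from omegaconf import OmegaConf

# base config, e.g. sourced first in the defaults-list
base = OmegaConf.create({"dataloader": {"batch_size": 32, "shuffle": False}})
# override config, e.g. sourced later in the defaults-list
override = OmegaConf.create({"dataloader": {"batch_size": 64}})

# dictionaries are merged sequentially: the last one wins per key
merged = OmegaConf.merge(base, override)
print(merged.dataloader.batch_size)  # 64
print(merged.dataloader.shuffle)     # False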
Important Special Keywords and Values¶
Beyond the keywords below, ??? denotes a value in a default.yaml that must be set by an inheriting config.
| Key | Type | Description |
|---|---|---|
| _target_ | str | The path to the object (class, function) to instantiate. |
| _recursive_ | bool | Do not eagerly instantiate sub-keys that comprise a _target_ when set to false. |
| _partial_ | bool | Instantiate the _target_ as a functools.partial with pre-set arguments and keywords. |
| _args_ | list | Positional arguments for the _target_. |
| _self_ | — | Only for defaults-lists. You can add _self_ to control the resolution order of the config itself. By default, _self_ is appended to the end. |
Object and Function Instantiation¶
hydra allows you to define Python objects and functions declaratively in yaml. The below example instantiates an AutoModelForSequenceClassification with roberta-base.
sequence_classification_model:
  # path to the callable we want to instantiate
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  # kwargs for the object we want to instantiate
  # can themselves be instantiated!
  pretrained_model_name_or_path: "roberta-base"
  num_labels: 3
  # optionally, positional arguments
  # _args_:
  #   - 0
  #   - 1
Note
Carefully check whether the parent node of sequence_classification_model has _recursive_: true or not! If it is set to false, sequence_classification_model has to be instantiated manually (i.e., hydra.utils.instantiate(cfg.{...}.sequence_classification_model)).
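A minimal sketch of such manual instantiation, assuming the yaml above sits under a parent with _recursive_: false:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# mirrors the yaml above; `_recursive_: false` on the parent defers instantiation
cfg = OmegaConf.create(
    {
        "_recursive_": False,
        "sequence_classification_model": {
            "_target_": "transformers.AutoModelForSequenceClassification.from_pretrained",
            "pretrained_model_name_or_path": "roberta-base",
            "num_labels": 3,
        },
    }
)

# instantiate explicitly once the model is actually needed
model = instantiate(cfg.sequence_classification_model)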
hydra yaml configuration comprises various reserved keys:
_target_ points to the object to instantiate, e.g. transformers.AutoModel.from_pretrained
_recursive_: false ensures that objects are not instantiated eagerly, but only when instantiated explicitly; trident takes care of instantiating your objects at the right time to bypass hydra limitations
_partial_: true is common to instantiate functions with pre-set arguments and keywords (see the sketch below)
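For example, _partial_: true resolves the target into a functools.partial; a sketch with a torch optimizer:

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# `_partial_: true` yields functools.partial(torch.optim.AdamW, lr=1e-4),
# which still expects the model parameters as its positional argument
optim_cfg = OmegaConf.create(
    {"_target_": "torch.optim.AdamW", "_partial_": True, "lr": 1e-4}
)
optimizer_factory = instantiate(optim_cfg)

model = torch.nn.Linear(2, 2)
optimizer = optimizer_factory(model.parameters())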
Packaging: combine various configs¶
Paired with absolute paths, package directives allow you to seamlessly relocate configuration in the defaults-list.
Imagine you now group a series of dataspecs for a particular benchmark dataset in configs/dataspecs/xnli_val_test.yaml.
Note
In the below configuration, the leading / denotes an absolute config path!
defaults:
  - /dataspec/nli@validation_xnli_en
  - /dataspec/nli@test_xnli_en
  # we can easily add more languages

validation_xnli_en:
  dataset:
    path: xnli
    name: en
    split: validation

test_xnli_en:
  dataset:
    path: xnli
    name: en
    split: test
The above sources the ./configs/dataspec/nli.yaml into the validation_xnli_en and test_xnli_en keys of the xnli_val_test.yaml group of dataspecs. We can then refine individual configurations for the particular dataspec in the main configuration.
Later on, we can seamlessly include or exclude groups of dataspecs (i.e., benchmarks) like the above.
Project Structure¶
An exemplary structure for a user project is shown below:
src comprises required code, typically for processing and evaluation, as referred to in the config
# yaml configuration
your-project
├── configs
│ ├── config.yaml # inherits all `default.yaml`
│ ├── experiment # typical entry point, 2nd-level `config.yaml` for your experiment
│ │ ├── default.yaml
│ │ └── nli.yaml
│ ├── module
│ │ ├── optimizer # torch.optim
│ │ │ ├── adam.yaml
│ │ │ └── adamw.yaml
│ │ ├── scheduler # learning-rate scheduler
│ │ │ └── linear_warm_up.yaml
│ │ ├── default.yaml
│ │ └── text_classification.yaml
│ ├── datamodule
│ │ ├── default.yaml
│ │ └── mnli_train.yaml
│ ├── dataspec # defines [dataset, preprocessing, dataloader, evaluation]
│ │ ├── dataloader
│ │ │ └── default.yaml
│ │ ├── evaluation
│ │ │ └── text_classification.yaml
│ │ ├── default.yaml # inherits dataloader/default.yaml
│ │ ├── text_classification.yaml # task-specific: inherits default.yaml and evaluation/text_classification.yaml
│ │ └── nli.yaml # dataset-group specific: inherits text_classification.yaml
│ ├── dataspecs # defines dataset-specific groups of dataspec
│ │ ├── mnli_train.yaml
│ │ ├── xnli_val_test.yaml
│ │ ├── amnli_val_test.yaml
│ │ └── indicxnli_val_test.yaml
│ ├── hydra
│ │ └── default.yaml
│ ├── logger
│ │ ├── csv.yaml
│ │ └── wandb.yaml
│ ├── callbacks # defines callbacks like lightning.pytorch.ModelCheckpoint
│ │ └── default.yaml
│ └── trainer # defines lightning.pytorch.Trainer
│ ├── debug.yaml
│ └── default.yaml
└── src # typical code folder structure
└── tasks
└── text_classification
├── evaluation.py
└── processing.py
Components¶
TridentModule¶
TridentModule extends the LightningModule. The configuration defines all required components for a TridentModule:
model: the _target_ pointing to your model constructor, from which TridentModule.model will be initialized
optimizer: the optimizer for all TridentModule parameters
scheduler: the learning-rate scheduler for the optimizer
The default.yaml sets up the AdamW optimizer and a linear warm-up learning-rate scheduler by default.
# _target_ is hydra-lingo to point to the object (class, function) to instantiate
_target_: trident.TridentModule
# _recursive_: true would mean all keyword arguments are /already/ instantiated
# when passed to `TridentModule`
_recursive_: false

defaults:
  # interleaved with setup so instantiated later (recursive false)
  - optimizer: adamw # see config/module/optimizer/adamw.yaml for default
  - scheduler: linear_warm_up # see config/module/scheduler/linear_warm_up.yaml for default

# required to be defined by user
model: ???
A common pattern is that users create a configs/module/task.yaml that predefines shared model and evaluation logic for a particular task.
defaults:
  - default
  - evaluation: text_classification

model:
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  num_labels: ???
  pretrained_model_name_or_path: ???
The model constructor points to transformers.AutoModelForSequenceClassification.from_pretrained. The actual model and number of labels will be defined either in the experiment configuration or on the CLI (cf. ???).
TridentDataspec¶
A TridentDataspec class encapsulates the configuration for data handling in a machine learning workflow. It manages various aspects of data processing including dataset instantiation, preprocessing, dataloading, and evaluation.
Configuration Keys
dataset: Specifies how the dataset should be instantiated.
dataloader: Defines the instantiation of the DataLoader.
preprocessing (optional): Details the methods or function calls for dataset preprocessing.
evaluation (optional): Outlines any post-processing steps and metrics for dataset evaluation.
misc (optional): Reserved for miscellaneous settings that do not fit under other keys.
dataset¶
The dataset instantiates a Dataset that is compatible with a PyTorch DataLoader. Any preprocessing should be defined in the corresponding preprocessing configuration of the TridentDataspec.
mnli_train:
  dataset: # required, config on how to instantiate dataset
    _target_: datasets.load_dataset
    path: glue
    name: mnli
    split: train
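The dataset config above corresponds to the following call into the 🤗 datasets library:

from datasets import load_dataset

# equivalent to the `mnli_train.dataset` config above
dataset = load_dataset(path="glue", name="mnli", split="train")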
preprocessing¶
The preprocessing key in the configuration details the steps for preparing the dataset. It includes two special keys, method and apply, each holding dictionaries for specific preprocessing actions.
method: Contains dictionaries of class methods along with their keyword arguments. These are typically methods of the dataset class.
apply: Comprises dictionaries of user-defined functions, along with their keyword arguments, to be applied to the dataset. Be mindful that functions of apply, unlike most other keys, typically do not take _partial_: true.
The preprocessing functions take the Dataset as the first positional argument. The functions are called in the order of the configuration. Note that "method" is a convenience keyword; the same can be achieved by pointing to the class method in the "_target_" of an "apply" function.
Example Configuration
preprocessing:
  method:
    map: # dataset.map of huggingface `datasets.arrow_dataset.Dataset`
      function:
        _target_: src.tasks.text_classification.processing.preprocess_fn
        _partial_: true
        column_names:
          text: premise
          text_pair: hypothesis
        tokenizer:
          _partial_: true
          _target_: transformers.tokenization_utils_base.PreTrainedTokenizerBase.__call__
          self:
            _target_: transformers.AutoTokenizer.from_pretrained
            pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
          padding: false
          truncation: true
          max_length: 128
    # unify output format of MNLI and XNLI
    set_format:
      columns:
        - "input_ids"
        - "attention_mask"
        - "label"
  apply:
    example_function:
      _target_: mod.package.example_function
      # ...
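For illustration, a hypothetical preprocess_fn matching the map configuration above; examples is supplied by Dataset.map, while column_names and tokenizer are the keyword arguments pre-set via _partial_: true in the config:

from typing import Any, Callable


# hypothetical implementation of `function` above: `examples` is a batch of
# dataset rows handed over by `Dataset.map`; `column_names` maps tokenizer
# arguments to dataset columns; `tokenizer` is the pre-configured tokenizer call
def preprocess_fn(
    examples: dict[str, list],
    column_names: dict[str, str],
    tokenizer: Callable,
) -> dict[str, Any]:
    return tokenizer(
        text=examples[column_names["text"]],
        text_pair=examples[column_names["text_pair"]],
    )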
dataloader¶
The DataLoader configuration (configs/dataspec/dataloader/default.yaml) is preset with reasonable defaults, accommodating typical use cases.
Example Configuration
_target_: torch.utils.data.dataloader.DataLoader
collate_fn:
  _target_: transformers.data.data_collator.DataCollatorWithPadding
  tokenizer:
    _target_: transformers.AutoTokenizer.from_pretrained
    pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
  max_length: ???
batch_size: 32
pin_memory: true
shuffle: false
num_workers: 4
evaluation¶
The logic of evaluation is defined in ./configs/dataspec/evaluation/text_classification.yaml. It is common to define evaluation per type of task.
The evaluation configuration segments into the fields prepare, step_outputs, and metrics.
See also
trident.utils.types.EvaluationDict
prepare¶
prepare defines functions called on the batch, the model outputs, or the collected step_outputs.
The TridentModule passes the below keyword arguments to facilitate evaluation. Since the TridentModule extends the LightningModule, useful attributes like trainer and trainer.datamodule are available at runtime.
Example Configuration
prepare:
  # takes (trident_module: TridentModule, batch: dict, split: Split, dataset_name: str) -> dict
  batch: null
  # takes (trident_module: TridentModule, outputs: dict, batch: dict, split: Split, dataset_name: str) -> dict
  outputs:
    _partial_: true
    _target_: src.tasks.text_classification.evaluation.get_preds
  # takes (trident_module: TridentModule, step_outputs: dict, split: Split, dataset_name: str) -> dict
  step_outputs: null
where get_preds is defined as follows and merely adds the predictions to the outputs dictionary:
def get_preds(outputs: dict, *args, **kwargs) -> dict:
    outputs["preds"] = outputs["logits"].argmax(dim=-1)
    return outputs
See also
trident.utils.enums.Split, trident.utils.types.PrepareDict
step_outputs¶
step_outputs defines which keys are collected from the batch or outputs dictionary, per step, into the flattened outputs dict per evaluation dataloader. The flattened dictionary then holds the corresponding key-value pairs as input to the prepare.step_outputs function, which ultimately serves as input to the metrics computed at the end of an evaluation loop.
Note
trident ensures that after each evaluation loop, lists of np.ndarrays or torch.Tensors are correctly stacked into a single array with appropriate dimensions.
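A standalone sketch of that stacking behavior (illustrative only; trident performs this internally at the end of the loop):

import torch

# two hypothetical evaluation steps, each collecting `labels` (from `batch`)
# and `preds` (from `outputs`) as configured under `step_outputs`
step_outputs = [
    {"labels": torch.tensor([0, 2]), "preds": torch.tensor([0, 1])},
    {"labels": torch.tensor([1, 1]), "preds": torch.tensor([1, 1])},
]

# per key, the per-step tensors are stacked into one tensor for the epoch
flattened = {
    key: torch.cat([step[key] for step in step_outputs])
    for key in step_outputs[0]
}
print(flattened["preds"])  # tensor([0, 1, 1, 1])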
Example Configuration
# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  # can be a str
  batch: labels
  # or a list[str]
  outputs:
    - "preds"
    - "logits"
See also
trident.utils.flatten_dict()
metrics¶
metrics denotes a dictionary for all evaluated metrics. For instance, a metric such as acc may contain:
metric: how to instantiate the metric; typically a partial function; must return a Callable.
compute_on: either eval_step or epoch_end, with the latter being the default.
kwargs: a custom syntax to fetch the kwargs of metric from one of [trident_module, outputs, batch, cfg]; outputs refers to the model outputs when compute_on is set to eval_step and to step_outputs when compute_on is set to epoch_end.
In the NLI example:
The keyword preds for torchmetrics.functional.accuracy is sourced from outputs["preds"].
The keyword target for torchmetrics.functional.accuracy is sourced from outputs["labels"].
Example Configuration
metrics:
  # name of the metric, used e.g. for logging
  acc:
    # instructions to instantiate metric, preferably a torchmetrics.Metric
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
    # either "eval_step" or "epoch_end", defaults to "epoch_end"
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"
Note
It is also possible to use metrics for actions other than logging evaluation results. The metric function should then return None.
In the below case, we construct a “metric” that stores predictions to disk instead. A metric is not logged to wandb if the function returns None.
from pathlib import Path

import pandas as pd
import torch

from trident import TridentModule


def store_predictions(
    trident_module: TridentModule,
    preds: torch.Tensor,
    dataset_name: str,
    directory: str,
    *args,
    **kwargs,
) -> None:
    trainer = trident_module.trainer
    epoch = trainer.current_epoch
    p = pd.DataFrame(preds.cpu())
    p.columns = ["prediction"]
    path = Path(directory).joinpath(f"dataset={dataset_name}_epoch={epoch}.csv")
    p.to_csv(str(path), index_label="ids")
And the corresponding metric configuration
metrics:
  store_predictions:
    metric:
      _partial_: true
      _target_: $PATH_TO_FUNCTION.store_predictions
      directory: ${hydra:runtime.output_dir}
    compute_on: "epoch_end"
    kwargs:
      trident_module: "trident_module"
      preds: "outputs.preds"
      dataset_name: "dataset_name"
TridentDataModule¶
The default configuration (configs/datamodule/default.yaml) for a TridentDataModule defines how training and evaluation datasets are instantiated.
Each split is a dictionary of TridentDataspec.
_target_: trident.TridentDataModule
_recursive_: false
misc:
  # reserved key for general TridentDataModule configuration
train:
  # DictConfig of TridentDataspec
val:
  # DictConfig of TridentDataspec
test:
  # DictConfig of TridentDataspec
Config Composition¶
Note
Hierarchical config composition heavily relies on defaults-lists.
The below file tree is a common structure for a hierarchical TridentDatamodule configuration in our NLI example.
We will hierarchically:
Compose a general dataspec
Compose a task-specific text-classification dataspec
Compose an NLI dataspec
Compose a train, val, or test split via dataspecs
Compose a datamodule
configs
├── config.yaml
├── datamodule
│ └── default.yaml
├── dataspec
│ ├── dataloader
│ │ └── default.yaml
│ ├── evaluation
│ │ └── text_classification.yaml
│ ├── default.yaml
│ ├── nli.yaml
│ └── text_classification.yaml
└── dataspecs
├── mnli_train.yaml
├── xnli_val_test.yaml
└── amnli_val_test.yaml
Default¶
The general dataspec simply defines the default (./configs/dataspec/default.yaml) configuration.
defaults:
  - dataset: null
  # pull in the default dataloader
  - dataloader: default

dataset:
  _target_: datasets.load.load_dataset
Text Classification¶
defaults:
  - default
  - evaluation: text_classification # see TridentDataspec evaluation

# task-specific preprocessing
preprocessing:
  ... # see TridentDataspec preprocessing
NLI¶
The configs/dataspec/nli.yaml simply extends the task-specific text_classification.yaml by specifying columns for the tokenizer in preprocessing.
defaults:
  - text_classification

preprocessing:
  method:
    map:
      function:
        # column_names denotes the input to the tokenizer during preprocessing
        column_names:
          text: premise
          text_pair: hypothesis
Dataspecs¶
We can now compose dataspecs, which group TridentDataspec configurations for entire datasets.
The configs/dataspecs/xnli_val_test.yaml leverages hydra package directives to put the nli configuration into the corresponding dataspec keys.
defaults:
  # package `nli` of configs/dataspec into @{...}
  - /dataspec@validation_xnli_en: nli
  - /dataspec@validation_xnli_es: nli

validation_xnli_en:
  dataset:
    path: xnli
    name: en
    split: validation

validation_xnli_es:
  dataset:
    path: xnli
    name: es
    split: validation
# ... can extend this to the entire XNLI benchmark for val and test splits
NLI Datamodules¶
Datamodule Configurations¶
We can now use package directives to include the configuration from the configs/dataspecs/xnli_val_test.yaml file into the val and test keys of the TridentDatamodule.
Warning
When using packaging, make sure to provide a list of dataspecs configurations to allow for the merging of multiple datamodule configurations in the experiment configuration.
Important:
A single TridentDataspec in train of the TridentDataModule will return a batch of dict[str, Any] at runtime
Multiple TridentDataspec in train of the TridentDataModule will return a batch of dict[str, dict[str, Any]] for multi-dataset training at runtime (see the sketch below)
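A sketch of the resulting batch structures at runtime (keys are illustrative):

# single TridentDataspec in `train`: the batch is a flat dict
batch = {"input_ids": ..., "attention_mask": ..., "labels": ...}

# multiple TridentDataspec in `train`: one nested dict per dataspec
batch = {
    "mnli_train": {"input_ids": ..., "attention_mask": ..., "labels": ...},
    "xnli_train": {"input_ids": ..., "attention_mask": ..., "labels": ...},
}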
Example Configuration
We now package the configs/dataspecs/xnli_val_test.yaml into a list configuration in datamodule.val of our experiment. We can thereby easily include and exclude various datasets for training, validation, or testing.
# variant A: training on a single dataset
defaults:
  - /dataspecs@datamodule.train: mnli_train
  - /dataspecs@datamodule.val:
      - xnli_val_test
      - amnli_val_test
      - indicxnli_val_test
  - /dataspecs@datamodule.test:
      - xnli_val_test
      - amnli_val_test
      - indicxnli_val_test

# variant B: training on multiple datasets
defaults:
  - /dataspecs@datamodule.train:
      - mnli_train
      - xnli_train
  # ...
Experiment¶
The experiment configuration also segments into a general default.yaml and a task-specific nli.yaml.
The run key is, next to module, datamodule, and trainer, a special key reserved for user configuration. The configuration of this key also gets saved in your logger (e.g., wandb).
defaults:
  - override /trainer: default
  - override /callbacks: default
  - override /logger: wandb

# the `run` namespace should hold your individual configuration
run:
  seed: 42
  task: ???

trainer:
  max_epochs: 10
  devices: 1
  precision: "16-mixed"
  deterministic: true
  inference_mode: false

# `_log_vars` infers the first training dataset
# for logging the batch size
_log_vars:
  # needed because hydra cannot index a list in an interpolation
  train_datasets: ${oc.dict.keys:datamodule.train}
  train_dataset: ${_log_vars.train_datasets[0]}
  train_batch_size: ${datamodule.train.${_log_vars.train_dataset}.dataloader.batch_size}

logger:
  wandb:
    name: "model=${module.model.pretrained_model_name_or_path}_epochs=${trainer.max_epochs}_bs=${_log_vars.train_batch_size}_lr=${module.optimizer.lr}_scheduler=${module.scheduler.num_warmup_steps}_seed=${run.seed}"
    tags:
      - "${module.model.pretrained_model_name_or_path}"
      - "bs=${_log_vars.train_batch_size}"
      - "lr=${module.optimizer.lr}"
      - "scheduler=${module.scheduler.num_warmup_steps}"
    project: ${run.task}
# @package _global_
# The above line is important! It sets the namespace of the config
defaults:
  - default
  # We can now combine `dataspecs` for training, validation, and testing
  - /dataspecs@datamodule.train:
      - mnli_train
  - /dataspecs@datamodule.val:
      - xnli_val_test
      - indicxnli_val_test
      - amnli_val_test
  - override /module: text_classification

run:
  task: nli

module:
  model:
    pretrained_model_name_or_path: "xlm-roberta-base"
    num_labels: 3
Commandline Interface¶
hydra allows you to simply set configuration items on the command line. See the hydra documentation for more information.
# change the learning rate
python -m trident.run experiment=nli module.optimizer.lr=0.0001
# set a different optimizer
python -m trident.run experiment=nli module.optimizer=adam
# no lr scheduler
python -m trident.run experiment=nli module.scheduler=null
Warning
The command-line interface only supports absolute config paths. For instance, overriding defaults with relative paths at runtime from the CLI is not possible.