trident in 20 minutes¶
The walkthrough first introduces common concepts of hydra and then walks through an exemplary text-classification pipeline for sequence-pair classification (NLI). The example NLI project is embedded in the repository.
hydra primer¶
It is important to have basic familiarity with hydra, which shines at bottom-up hierarchical yaml configuration.
Two key features of hydra for trident are
The defaults-list for hierarchical configuration composition.
Package directives to cleanly combine configurations.
Below is a brief primer of hydra.
Context¶
For trident, we will use hydra to declare our
- Trainer: declares how training, validation, and testing are run
- TridentModule: declares your “model”
- TridentDataModule: declares your train, val, and test splits
- And other components like logging and checkpointing
Hierarchical Configuration¶
hydra is a Python yaml framework to compose complex pipelines. The directory tree of your trident configuration in our simplified example may look like the below file tree.
In what follows, we will focus on the hierarchy of the configuration below in the case of dataspec. A dataspec defines the dataset, the associated pipelines for preprocessing and evaluation, and the dataloader for the preprocessed dataset.
configs
├── config.yaml
├── dataspec # defines dataset, preprocessing, evaluation, and dataloader
│ ├── dataloader
│ │ ├── default.yaml
│ │ └── train.yaml
│ ├── evaluation
│ │ └── text_classification.yaml
│ ├── preprocessing
│ │ └── shots.yaml
│ ├── default.yaml # general
│ ├── text_classification.yaml # task-specific: inherits general
│ └── nli.yaml # more dataset-specific: inherits task-specific
├── dataspecs # will group a dataspec for a particular benchmark
└── module
    ├── default.yaml
    └── my_module.yaml
The configs/dataspec/default.yaml defines the defaults.
- The null default of dataset reserves the key for a future override
- The dataloader: default means that the ./configs/dataspec/dataloader/default.yaml config will be sourced as the dataloader key for the default dataspec
- We define the default _target_ of a dataset, which typically uses the 🤗 datasets library
defaults:
  # a null default reserves the key for future override
  - dataset: null
  # the default dataloader is sourced into the dataloader key of the dataspec
  # i.e. ./configs/dataspec/dataloader/default.yaml will be sourced in the `dataloader` key of the config
  - dataloader: default
  # _self_ allows you to control the resolution order of the config itself
  # _self_ is not required and appended to the end of the defaults list by default
  - _self_
dataset:
  # most datasets use the Hugging Face datasets library
  _target_: datasets.load.load_dataset
The configs/dataspec/text_classification.yaml then extends the default.yaml
defaults:
- default
- evaluation: text_classification
# .. more configuration
The above sources the configuration of configs/dataspec/default.yaml as well as the task-specific ./configs/dataspec/evaluation/text_classification.yaml into the evaluation key of the text_classification.yaml.
Lastly, ./configs/dataspec/nli.yaml defines even more specific configuration.
defaults:
- text_classification
# ... nli-specific configuration
Notes:
- hydra follows an additive configuration paradigm: design your configuration to incrementally add what’s required! Unsetting or removing options is often very unwieldy
- The configuration relies on relative paths in the configuration folder structure
- In the defaults-list, the last one wins (dictionaries are merged sequentially)
- A config's own content typically comes last (can be controlled via _self_, see the official documentation)
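The "last one wins" rule can be pictured with a small, dependency-free sketch. This is not hydra's implementation: `merge` and `compose` are hypothetical helpers, and the three dictionaries are made-up stand-ins for default.yaml, text_classification.yaml, and nli.yaml.

```python
from copy import deepcopy


def merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base`; `override` wins on conflicts."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def compose(*configs: dict) -> dict:
    """Merge configs in defaults-list order, so later entries take precedence."""
    result: dict = {}
    for config in configs:
        result = merge(result, config)
    return result


# Hypothetical stand-ins for default.yaml, text_classification.yaml, and nli.yaml
default = {"dataset": {"_target_": "datasets.load.load_dataset"}, "dataloader": {"batch_size": 32}}
text_classification = {"evaluation": {"metrics": ["acc"]}}
nli = {"dataloader": {"batch_size": 16}}

dataspec = compose(default, text_classification, nli)
print(dataspec["dataloader"]["batch_size"])  # 16: the most specific config wins
print(dataspec["dataset"]["_target_"])       # inherited unchanged from the default
```

Because merging only ever adds or overwrites keys, removing a key set by an earlier config has no natural expression, which is why the additive design advice above matters.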
Important Special Keywords and Values¶
Beyond the keywords below, ??? denotes a value in default.yaml that must be set in an inheriting config.
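Key | Type | Description |
---|---|---|
_target_ | str | The dotted path of the object (class or function) to instantiate |
_recursive_ | bool | Do not eagerly instantiate sub-keys that comprise a _target_ when set to false |
_partial_ | bool | Instantiate the _target_ as a functools.partial when set to true |
_args_ | list | Positional arguments for the _target_ |
_self_ | | Only for defaults-lists. You can add _self_ to control the resolution order of the config itself. By default, _self_ is appended to the end. |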
Object and Function Instantiation¶
hydra allows you to define Python objects and functions declaratively in yaml. The below example instantiates an AutoModelForSequenceClassification with roberta-base.
sequence_classification_model:
  # path to the callable we want to instantiate
  # (Auto classes are created via `from_pretrained`, not via their constructor)
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  # kwargs for the object we want to instantiate
  # can be themselves instantiated!
  pretrained_model_name_or_path: "roberta-base"
  num_labels: 3
  # optionally positional arguments
  # _args_:
  #   - 0
  #   - 1
Note
Carefully check whether the parent node of sequence_classification_model has _recursive_: true or not! If it is set to false, the sequence_classification_model has to be instantiated manually (i.e., hydra.utils.instantiate(cfg.{...}.sequence_classification_model)).
hydra yaml configuration comprises various reserved keys:
- _target_: transformers.AutoModel.from_pretrained: points to the object (class or function) to instantiate
- _recursive_: false ensures that objects are not instantiated eagerly, but only when instantiated explicitly. trident takes care of instantiating your objects at the right time to bypass hydra limitations
- _partial_: true is commonly used to instantiate functions with pre-set arguments and keywords
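To make the reserved keys more concrete, here is a simplified, standard-library-only sketch of what an instantiation routine does with them. It is an illustration, not the code of hydra.utils.instantiate: `locate` and `instantiate` are hypothetical helpers, and the targets are standard-library callables chosen so the example runs anywhere.

```python
import importlib
from functools import partial
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path such as 'fractions.Fraction' to a Python object."""
    parts = path.split(".")
    # Try the longest importable module prefix, then walk the remaining attributes.
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"Cannot resolve {path!r}")


def instantiate(cfg: Any) -> Any:
    """Simplified stand-in for an instantiate call (illustration only)."""
    if not isinstance(cfg, dict) or "_target_" not in cfg:
        return cfg
    target = locate(cfg["_target_"])
    args = cfg.get("_args_", [])
    # Everything that is not a reserved `_key_` is passed as a keyword argument.
    kwargs = {k: v for k, v in cfg.items() if not (k.startswith("_") and k.endswith("_"))}
    if cfg.get("_recursive_", True):
        kwargs = {k: instantiate(v) for k, v in kwargs.items()}
    if cfg.get("_partial_", False):
        return partial(target, *args, **kwargs)
    return target(*args, **kwargs)


half_cfg = {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 2}

half = instantiate(half_cfg)                                  # eager call
parse_binary = instantiate({"_target_": "builtins.int", "_partial_": True, "base": 2})
eager = instantiate({"_target_": "builtins.dict", "child": half_cfg})
lazy = instantiate({"_target_": "builtins.dict", "_recursive_": False, "child": half_cfg})

print(half)                  # 1/2
print(parse_binary("101"))   # 5
print(type(eager["child"]).__name__, type(lazy["child"]).__name__)  # Fraction dict
```

With `_recursive_: false` the nested `child` stays a plain config dictionary, so the receiving object can decide when to instantiate it.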
Packaging: combine various configs¶
Paired with absolute paths, package directives allow you to seamlessly relocate configuration in the defaults-list.
Imagine you now group a series of dataspecs for a particular benchmark dataset in configs/dataspecs/xnli_val_test.yaml.
Note
In the below configuration, the leading / denotes an absolute config path!
defaults:
  - /dataspec/nli@validation_xnli_en
  - /dataspec/nli@test_xnli_en
  # we can easily add more languages
validation_xnli_en:
  dataset:
    path: xnli
    name: en
    split: validation
test_xnli_en:
  dataset:
    path: xnli
    name: en
    split: test
The above sources the ./configs/dataspec/nli.yaml into the validation_xnli_en and test_xnli_en keys of the xnli_val_test.yaml group of dataspecs. We can then refine individual configurations for the particular dataspec in the main configuration.
Later on, we can seamlessly include or exclude groups of dataspecs (i.e., benchmarks) like the above.
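Conceptually, a package directive places a copy of a config under a new key before the file's own content is merged on top. The sketch below mimics that with plain dictionaries; `place` and `_update` are hypothetical helpers and the `nli` dictionary is a made-up stand-in for the composed configs/dataspec/nli.yaml.

```python
from copy import deepcopy


def _update(node: dict, values: dict) -> None:
    """In-place recursive update; existing keys are overwritten by `values`."""
    for key, value in values.items():
        if isinstance(value, dict) and isinstance(node.get(key), dict):
            _update(node[key], value)
        else:
            node[key] = value


def place(root: dict, package: str, config: dict) -> None:
    """Merge a copy of `config` into `root` under the dotted `package` path."""
    node = root
    for key in package.split("."):
        node = node.setdefault(key, {})
    _update(node, deepcopy(config))


# Hypothetical content of configs/dataspec/nli.yaml after composition
nli = {"dataset": {"_target_": "datasets.load.load_dataset"}, "dataloader": {"batch_size": 32}}

dataspecs: dict = {}
# defaults-list entries: /dataspec/nli@validation_xnli_en and /dataspec/nli@test_xnli_en
place(dataspecs, "validation_xnli_en", nli)
place(dataspecs, "test_xnli_en", nli)
# the body of xnli_val_test.yaml then refines each packaged copy
place(dataspecs, "validation_xnli_en", {"dataset": {"path": "xnli", "name": "en", "split": "validation"}})
place(dataspecs, "test_xnli_en", {"dataset": {"path": "xnli", "name": "en", "split": "test"}})

print(dataspecs["validation_xnli_en"]["dataset"]["split"])  # validation
print(dataspecs["test_xnli_en"]["dataset"]["split"])        # test
```

Each packaged key holds an independent copy, so refining one split leaves the other untouched.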
Project Structure¶
An exemplary structure for a user project is shown below:
src comprises required code, typically for processing and evaluation, as referred to in the config
# yaml configuration
your-project
├── configs
│ ├── config.yaml # inherits all `default.yaml`
│ ├── experiment # typical entry point, 2nd-level `config.yaml` for your experiment
│ │ ├── default.yaml
│ │ └── nli.yaml
│ ├── module
│ │ ├── optimizer # torch.optim
│ │ │ ├── adam.yaml
│ │ │ └── adamw.yaml
│ │ ├── scheduler # learning-rate scheduler
│ │ │ └── linear_warm_up.yaml
│ │ ├── default.yaml
│ │ └── text_classification.yaml
│ ├── datamodule
│ │ ├── default.yaml
│ │ └── mnli_train.yaml
│ ├── dataspec # defines [dataset, preprocessing, dataloader, evaluation]
│ │ ├── dataloader
│ │ │ └── default.yaml
│ │ ├── evaluation
│ │ │ └── text_classification.yaml
│ │ │ # inherits dataloader/default.yaml
│ │ ├── default.yaml
│ │ │ # task-specific dataspecs
│ │ │ # inherits default.yaml and evaluation/text_classification.yaml
│ │ ├── text_classification.yaml
│ │ │ # dataset-group specific dataspecs
│ │ │ # inherits text_classification.yaml
│ │ ├── nli.yaml
│ │ │ # dataset-specific dataspecs
│ ├── dataspecs # defines groups of dataspec
│ │ ├── mnli_train.yaml
│ │ ├── xnli_val_test.yaml
│ │ ├── amnli_val_test.yaml
│ │ └── indicxnli_val_test.yaml
│ ├── hydra
│ │ └── default.yaml
│ ├── logger
│ │ ├── csv.yaml
│ │ └── wandb.yaml
│ ├── callbacks # defines callbacks like lightning.pytorch.ModelCheckpoint
│ │ └── default.yaml
│ └── trainer # defines lightning.pytorch.Trainer
│ ├── debug.yaml
│ └── default.yaml
└── src # typical code folder structure
    └── tasks
        └── text_classification
            ├── evaluation.py
            └── processing.py
Components¶
TridentModule¶
TridentModule extends the LightningModule. The configuration defines all required components for a TridentModule:
- model: _target_ pointing to your model constructor, with which TridentModule.model will be initialized
- optimizer: the optimizer for all TridentModule parameters
- scheduler: the learning-rate scheduler for the optimizer
The default.yaml by default sets up an AdamW optimizer and a linear learning-rate scheduler.
# _target_ is hydra-lingo to point to the object (class, function) to instantiate
_target_: trident.TridentModule
# _recursive_: true would mean all keyword arguments are /already/ instantiated
# when passed to `TridentModule`
_recursive_: false
defaults:
# interleaved with setup so instantiated later (recursive false)
- optimizer: adamw.yaml # see config/module/optimizer/adamw.yaml for default
- scheduler: linear_warm_up # see config/module/scheduler/linear_warm_up.yaml for default
# required to be defined by user
model: ???
A common pattern is that users create a configs/module/task.yaml that predefines shared model and evaluation logic for a particular task.
defaults:
  - default
  - evaluation: text_classification
model:
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  num_labels: ???
  pretrained_model_name_or_path: ???
The model constructor points to transformers.AutoModelForSequenceClassification.from_pretrained. The actual model and number of labels will be defined in either the experiment configuration or in the CLI (cf. ???).
TridentDataspec¶
A TridentDataspec class encapsulates the configuration for data handling in a machine learning workflow. It manages various aspects of data processing including dataset instantiation, preprocessing, dataloading, and evaluation.
Configuration Keys
- dataset: Specifies how the dataset should be instantiated.
- dataloader: Defines the instantiation of the DataLoader.
- preprocessing (optional): Details the methods or function calls for dataset preprocessing.
- evaluation (optional): Outlines any post-processing steps and metrics for dataset evaluation.
- misc (optional): Reserved for miscellaneous settings that do not fit under other keys.
dataset¶
The dataset instantiates a Dataset that is compatible with a PyTorch DataLoader. Any preprocessing should be defined in the corresponding preprocessing configuration of the TridentDataspec.
mnli_train:
  dataset: # required, config on how to instantiate dataset
    _target_: datasets.load_dataset
    path: glue
    name: mnli
    split: train
preprocessing¶
The preprocessing key in the configuration details the steps for preparing the dataset. It includes two special keys, method and apply, each holding dictionaries for specific preprocessing actions.
- method: Contains dictionaries of class methods along with their keyword arguments. These are typically methods of the dataset class.
- apply: Comprises dictionaries of user-defined functions, along with their keyword arguments, to be applied to the dataset. Be mindful that functions of apply, unlike most other keys, typically do not take _partial_: true
The preprocessing functions take the Dataset as the first positional argument. The functions are called in the order of the configuration. Note that "method" is a convenience keyword; the same can be achieved by pointing to the class method in "_target_" of an "apply" function.
Example Configuration
preprocessing:
  method:
    map: # dataset.map of huggingface `datasets.arrow_dataset.Dataset`
      function:
        _target_: src.tasks.text_classification.processing.preprocess_fn
        _partial_: true
        column_names:
          text: premise
          text_pair: hypothesis
        tokenizer:
          _partial_: true
          _target_: transformers.tokenization_utils_base.PreTrainedTokenizerBase.__call__
          self:
            _target_: transformers.AutoTokenizer.from_pretrained
            pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
          padding: false
          truncation: true
          max_length: 128
    # unify output format of MNLI and XNLI
    set_format:
      columns:
        - "input_ids"
        - "attention_mask"
        - "label"
  apply:
    example_function:
      _target_: mod.package.example_function
      # ...
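The ordering rules for method and apply can be sketched in plain Python. This is an assumption-laden illustration rather than trident's code: `ToyDataset`, `join_pair`, `keep_first`, and `preprocess` are hypothetical, and the config holds already-instantiated callables (under made-up `fn` and `kwargs` keys) instead of `_target_` entries.

```python
from functools import partial
from typing import Any, Callable


class ToyDataset:
    """Minimal stand-in for a dataset class that offers a `map` method."""

    def __init__(self, rows: list):
        self.rows = rows

    def map(self, function: Callable[[dict], dict]) -> "ToyDataset":
        return ToyDataset([function(dict(row)) for row in self.rows])


def join_pair(row: dict, column_names: dict) -> dict:
    """Toy replacement for tokenization: join the two configured columns."""
    row["text"] = f'{row[column_names["text"]]} [SEP] {row[column_names["text_pair"]]}'
    return row


def keep_first(dataset: ToyDataset, n: int) -> ToyDataset:
    """User-defined function: receives the dataset as first positional argument."""
    return ToyDataset(dataset.rows[:n])


def preprocess(dataset: Any, cfg: dict) -> Any:
    """Run `method` and `apply` entries in the order they appear in `cfg`."""
    for kind, entries in cfg.items():
        for name, spec in entries.items():
            if kind == "method":
                # call the method of the dataset object by name
                dataset = getattr(dataset, name)(**spec)
            elif kind == "apply":
                # call the function with the dataset as first positional argument
                dataset = spec["fn"](dataset, **spec.get("kwargs", {}))
            else:
                raise ValueError(f"Unknown preprocessing key: {kind}")
    return dataset


cfg = {
    "method": {
        "map": {"function": partial(join_pair, column_names={"text": "premise", "text_pair": "hypothesis"})}
    },
    "apply": {"keep_first": {"fn": keep_first, "kwargs": {"n": 1}}},
}
raw = ToyDataset([{"premise": "a", "hypothesis": "b"}, {"premise": "c", "hypothesis": "d"}])
processed = preprocess(raw, cfg)
print(processed.rows)  # [{'premise': 'a', 'hypothesis': 'b', 'text': 'a [SEP] b'}]
```

Note how the function handed to `map` is a partial (it is called later, per row), whereas the `apply` function is called directly on the dataset.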
dataloader¶
The DataLoader configuration (configs/dataspec/dataloader/default.yaml) is preset with reasonable defaults, accommodating typical use cases.
Example Configuration
_target_: torch.utils.data.dataloader.DataLoader
collate_fn:
  _target_: transformers.data.data_collator.DataCollatorWithPadding
  tokenizer:
    _target_: transformers.AutoTokenizer.from_pretrained
    pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
  max_length: ???
batch_size: 32
pin_memory: true
shuffle: false
num_workers: 4
evaluation¶
The logic of evaluation is defined in ./configs/dataspec/evaluation/text_classification.yaml. It is common to define evaluation per type of task.
The evaluation configuration segments into the fields prepare, step_outputs, and metrics.
See also
trident.utils.types.EvaluationDict
prepare¶
prepare defines functions called on the batch, the model outputs, or the collected step_outputs.
The TridentModule hands the below keywords to facilitate evaluation. Since the TridentModule extends the LightningModule, useful attributes like trainer and trainer.datamodule are available at runtime.
Example Configuration
prepare:
  # takes (trident_module: TridentModule, batch: dict, split: Split, dataset_name: str) -> dict
  batch: null
  # takes (trident_module: TridentModule, outputs: dict, batch: dict, split: Split, dataset_name: str) -> dict
  outputs:
    _partial_: true
    _target_: src.tasks.text_classification.evaluation.get_preds
  # takes (trident_module: TridentModule, step_outputs: dict, split: Split, dataset_name: str) -> dict
  step_outputs: null
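where get_preds is defined as follows and merely adds the predicted classes (the argmax over the logits) to the outputs
def get_preds(outputs: dict, *args, **kwargs) -> dict:
    # predicted class = index of the largest logit along the last dimension
    outputs["preds"] = outputs["logits"].argmax(dim=-1)
    return outputs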
See also
trident.utils.enums.Split
, trident.utils.types.PrepareDict
step_outputs¶
step_outputs defines what keys are collected from a batch or outputs dictionary, per step, into the flattened outputs dict per evaluation dataloader. The flattened dictionary then holds the corresponding key-value pairs as input to the prepare_step_outputs function, which ultimately serves as input to metrics computed at the end of an evaluation loop.
Note
trident ensures that after each evaluation loop, lists of np.ndarrays or torch.Tensors are correctly stacked into a single array with appropriate dimensions.
Example Configuration
# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  # can be a str
  batch: labels
  # or a list[str]
  outputs:
    - "preds"
    - "logits"
See also
trident.utils.flatten_dict()
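The collection step can be pictured as follows. This is a minimal sketch with plain Python lists standing in for tensors; `as_list`, `collect_step`, and `flatten` are hypothetical helpers and not trident functions.

```python
from collections import defaultdict
from typing import Union


def as_list(keys: Union[str, list]) -> list:
    """Accept either a single key or a list of keys, as in the config."""
    return [keys] if isinstance(keys, str) else list(keys)


def collect_step(cfg: dict, batch: dict, outputs: dict, collected: dict) -> None:
    """Append the configured keys of one evaluation step to `collected`."""
    sources = {"batch": batch, "outputs": outputs}
    for source, keys in cfg.items():
        for key in as_list(keys):
            collected[key].append(sources[source][key])


def flatten(collected: dict) -> dict:
    """Concatenate the per-step lists into one flat list per key."""
    return {key: [value for step in steps for value in step] for key, steps in collected.items()}


step_outputs_cfg = {"batch": "labels", "outputs": ["preds"]}
collected = defaultdict(list)
# two evaluation steps with batch sizes 2 and 1
collect_step(step_outputs_cfg, {"labels": [0, 1]}, {"preds": [0, 2], "logits": None}, collected)
collect_step(step_outputs_cfg, {"labels": [2]}, {"preds": [2], "logits": None}, collected)
step_outputs = flatten(collected)
print(step_outputs)  # {'labels': [0, 1, 2], 'preds': [0, 2, 2]}
```

Keys that are not listed in the config (here `logits`) are simply not collected, which keeps memory usage down during long evaluation loops.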
metrics¶
metrics denotes a dictionary for all evaluated metrics. For instance, a metric such as acc may contain:
- metric: how to instantiate the metric; typically a partial function; must return a Callable.
- compute_on: Either eval_step or epoch_end, with the latter being the default.
- kwargs: A custom syntax to fetch kwargs of metric from one of the following: [trident_module, outputs, batch, cfg].
  - outputs refers to the model outputs when compute_on is set to eval_step and to step_outputs when compute_on is set to epoch_end.
- In the NLI example:
  - The keyword preds for torchmetrics.functional.accuracy is sourced from outputs["preds"].
  - The keyword target for torchmetrics.functional.accuracy is sourced from outputs["labels"].
Example Configuration
metrics:
  # name of the metric used eg for logging
  acc:
    # instructions to instantiate metric, preferably torchmetrics.Metric
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
    # either "eval_step" or "epoch_end", defaults to "epoch_end"
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"
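The kwargs syntax maps each metric keyword to a dotted lookup path. A rough sketch of how such a lookup could work is shown below, under the assumption that the available objects are gathered in one context dictionary; `resolve`, `compute_metric`, and the toy `accuracy` function are hypothetical and not part of trident or torchmetrics.

```python
from typing import Any, Callable


def resolve(path: str, context: dict) -> Any:
    """Follow a dotted path like 'outputs.preds' through dicts and attributes."""
    obj: Any = context
    for part in path.split("."):
        obj = obj[part] if isinstance(obj, dict) else getattr(obj, part)
    return obj


def compute_metric(metric: Callable, kwargs_cfg: dict, context: dict) -> Any:
    """Fetch every configured keyword from the context, then call the metric."""
    kwargs = {name: resolve(path, context) for name, path in kwargs_cfg.items()}
    return metric(**kwargs)


def accuracy(preds: list, target: list) -> float:
    """Toy accuracy on plain lists, standing in for a torchmetrics function."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


context = {
    "outputs": {"preds": [0, 2, 2, 1], "labels": [0, 1, 2, 1]},
    "dataset_name": "validation_xnli_en",
}
kwargs_cfg = {"preds": "outputs.preds", "target": "outputs.labels"}

acc = compute_metric(accuracy, kwargs_cfg, context)
print(acc)  # 0.75
```

The keys of the kwargs dictionary are the parameter names of the metric, while the values say where to find the data.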
Note
It is also possible to use metrics for other actions than logging evaluation. The metric function should then return None.
In the below case, we construct a “metric” that logs predictions instead. A metric is not logged to wandb if the function returns None.
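from pathlib import Path

import pandas as pd
import torch
from trident import TridentModule


def store_predictions(
    trident_module: TridentModule,
    preds: torch.Tensor,
    dataset_name: str,
    directory: str,
    *args,
    **kwargs,
):
    # the trainer of the LightningModule exposes the current epoch
    trainer = trident_module.trainer
    epoch = trainer.current_epoch
    # move predictions to the CPU and convert them for pandas
    p = pd.DataFrame(preds.cpu().numpy())
    p.columns = ["prediction"]
    path = Path(directory).joinpath(f"dataset={dataset_name}_epoch={epoch}.csv")
    # implicitly returns None, so nothing is logged as a metric
    p.to_csv(str(path), index_label="ids")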
And the corresponding metric configuration
metrics:
  store_predictions:
    metric:
      _partial_: true
      _target_: $PATH_TO_FUNCTION.store_predictions
      directory: ${hydra:runtime.output_dir}
    compute_on: "epoch_end"
    kwargs:
      trident_module: "trident_module"
      preds: "outputs.preds"
      dataset_name: "dataset_name"
TridentDataModule¶
The default configuration (configs/datamodule/default.yaml) for a TridentDataModule defines how training and evaluation datasets are instantiated.
Each split is a dictionary of TridentDataspec.
_target_: trident.TridentDataModule
_recursive_: false
misc:
# reserved key for general TridentDataModule configuration
train:
# DictConfig of TridentDataspec
val:
# DictConfig of TridentDataspec
test:
# DictConfig of TridentDataspec
Config Composition¶
Note
Hierarchical config composition heavily relies on defaults lists.
The below file tree is a common structure for a hierarchical TridentDataModule configuration in our NLI example.
We will hierarchically
- Compose a general dataspec
- Compose a task-specific text classification dataspec
- Compose an NLI dataspec
- Compose a train, val, or test split via dataspecs
- Compose a datamodule
configs
├── config.yaml
├── datamodule
│ └── default.yaml
├── dataspec
│ ├── dataloader
│ │ └── default.yaml
│ ├── evaluation
│ │ └── text_classification.yaml
│ ├── default.yaml
│ ├── nli.yaml
│ └── text_classification.yaml
└── dataspecs
    ├── mnli_train.yaml
    ├── xnli_val_test.yaml
    └── amnli_val_test.yaml
Default¶
The general dataspec simply defines the default (./configs/dataspec/default.yaml) configuration.
defaults:
- dataset: null
# pull in the default dataloader
- dataloader: default
dataset:
  _target_: datasets.load.load_dataset
Text Classification¶
defaults:
- default
- evaluation: text_classification # see TridentDataspec evaluation
# task specific preprocessing
preprocessing:
  ... # see TridentDataspec preprocessing
NLI¶
The configs/dataspec/nli.yaml simply extends the task-specific text_classification.yaml by specifying columns for the tokenizer in preprocessing.
defaults:
- text_classification
preprocessing:
  map:
    function:
      # column_names denotes input to the tokenizer during preprocessing
      column_names:
        text: premise
        text_pair: hypothesis
Dataspecs¶
We can now compose dataspecs, which group TridentDataspec configurations for entire datasets.
The configs/dataspecs/xnli_val_test.yaml leverages hydra package directives to put the nli configuration into the corresponding dataspec keys.
defaults:
# package `nli` of configs/dataspec into @{...}
- /dataspec@validation_xnli_en: nli
- /dataspec@validation_xnli_es: nli
# ... can extend this to the entire XNLI benchmark for val and test splits
validation_xnli_en:
  dataset:
    path: xnli
    name: en
    split: validation
validation_xnli_es:
  dataset:
    path: xnli
    name: es
    split: validation
# ... can extend this to the entire XNLI benchmark for val and test splits
NLI Datamodules¶
Datamodule Configurations¶
We can now use package directives to include the configuration from the configs/dataspecs/xnli_val_test.yaml file into the val and test keys of the TridentDataModule.
Warning
When using packaging, make sure to provide a list of dataspecs configurations to allow for the merging of multiple datamodule configurations in the experiment configuration.
- Important:
  - A single TridentDataspec in train of the TridentDataModule will return a batch of dict[str, Any] at runtime
  - Multiple TridentDataspec in train of the TridentDataModule will return a batch of dict[str, dict[str, Any]] for multi-dataset training at runtime
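Code that consumes training batches may want to handle both shapes uniformly. The helper below is one possible sketch: `iter_batches` is hypothetical, and it assumes that a multi-dataset batch can be recognized by all of its values being dictionaries.

```python
from typing import Any, Iterator, Optional, Tuple


def iter_batches(batch: dict) -> Iterator[Tuple[Optional[str], dict]]:
    """Yield (dataset_name, batch) pairs for single- and multi-dataset batches."""
    if batch and all(isinstance(value, dict) for value in batch.values()):
        # dict[str, dict[str, Any]]: one sub-batch per training dataspec
        yield from batch.items()
    else:
        # dict[str, Any]: a single training dataspec, no dataset name attached
        yield None, batch


single: dict = {"input_ids": [[1, 2, 3]], "labels": [0]}
multi: dict = {
    "mnli_train": {"input_ids": [[1, 2, 3]], "labels": [0]},
    "xnli_train": {"input_ids": [[4, 5]], "labels": [2]},
}

single_names = [name for name, _ in iter_batches(single)]
multi_names = [name for name, _ in iter_batches(multi)]
print(single_names)  # [None]
print(multi_names)   # ['mnli_train', 'xnli_train']
```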
Example Configuration
We now package the configs/dataspecs/xnli_val_test.yaml into a list configuration in datamodule.val of our experiment. We can thereby easily include and exclude various datasets for training, validation, or testing.
# variant A: training on a single dataset
defaults:
  - /dataspecs@datamodule.train: mnli_train
  - /dataspecs@datamodule.val:
    - xnli_val_test
    - amnli_val_test
    - indicxnli_val_test
  - /dataspecs@datamodule.test:
    - xnli_val_test
    - amnli_val_test
    - indicxnli_val_test
# variant B: training on multiple datasets
defaults:
  - /dataspecs@datamodule.train:
    - mnli_train
    - xnli_train
# ...
Experiment¶
The experiment configuration is also split into a general default.yaml and a task-specific nli.yaml.
The run key is, next to module, datamodule, and trainer, a special key reserved for user configuration. The configuration of this key also gets saved in your logger (e.g., wandb).
defaults:
  - override /trainer: default
  - override /callbacks: default
  - override /logger: wandb
# `run` namespace should hold your individual configuration
run:
  seed: 42
  task: ???
trainer:
  max_epochs: 10
  devices: 1
  precision: "16-mixed"
  deterministic: true
  inference_mode: false
# log vars infers first training dataset
# for logging batch size
_log_vars:
  # needed because hydra cannot index list in interpolation
  train_datasets: ${oc.dict.keys:datamodule.train}
  train_dataset: ${_log_vars.train_datasets[0]}
  train_batch_size: ${datamodule.train.${_log_vars.train_dataset}.dataloader.batch_size}
logger:
  wandb:
    name: "model=${module.model.pretrained_model_name_or_path}_epochs=${trainer.max_epochs}_bs=${_log_vars.train_batch_size}_lr=${module.optimizer.lr}_scheduler=${module.scheduler.num_warmup_steps}_seed=${run.seed}"
    tags:
      - "${module.model.pretrained_model_name_or_path}"
      - "bs=${_log_vars.train_batch_size}"
      - "lr=${module.optimizer.lr}"
      - "scheduler=${module.scheduler.num_warmup_steps}"
    project: ${run.task}
# @package _global_
# The above line is important! It sets the namespace of the config
defaults:
  - default
  # We can now combine `dataspecs` for training, validation, and testing
  - /dataspecs@datamodule.train:
    - mnli_train
  - /dataspecs@datamodule.val:
    - xnli_val_test
    - indicxnli_val_test
    - amnli_val_test
  - override /module: text_classification
run:
  task: nli
module:
  model:
    pretrained_model_name_or_path: "xlm-roberta-base"
    num_labels: 3
Commandline Interface¶
hydra allows you to set configuration items directly on the command line. See the hydra documentation on command-line overrides for more information.
# change the learning rate
python -m trident.run experiment=nli module.optimizer.lr=0.0001
# set a different optimizer
python -m trident.run experiment=nli module.optimizer=adam
# no lr scheduler
python -m trident.run experiment=nli module.scheduler=null
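A value override such as module.optimizer.lr=0.0001 addresses a nested key via its dotted path. The sketch below shows the idea on a plain dictionary; `parse_value` and `apply_override` are hypothetical helpers that cover simple value overrides only, not hydra's full override grammar or config-group selection such as experiment=nli.

```python
from typing import Any


def parse_value(raw: str) -> Any:
    """Interpret a few simple literals: null, booleans, ints, and floats."""
    if raw == "null":
        return None
    if raw in ("true", "false"):
        return raw == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_override(cfg: dict, override: str) -> None:
    """Set the value at the dotted key of `override` in the nested `cfg`."""
    dotted, raw = override.split("=", 1)
    *parents, leaf = dotted.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = parse_value(raw)


# Made-up composed configuration
cfg = {"module": {"optimizer": {"lr": 2e-05}, "scheduler": {"num_warmup_steps": 0.1}}}

apply_override(cfg, "module.optimizer.lr=0.0001")
apply_override(cfg, "module.scheduler=null")
print(cfg)  # {'module': {'optimizer': {'lr': 0.0001}, 'scheduler': None}}
```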
Warning
The commandline interface only supports absolute paths. For instance, overriding defaults at runtime from the CLI is not possible.