trident in 20 minutes

The walkthrough first introduces common concepts of hydra and then builds up an exemplary text-classification pipeline for sequence-pair classification (NLI). The example NLI project is embedded in the repository.

hydra primer

It is important to have basic familiarity with hydra, which shines at bottom-up hierarchical yaml configuration.

Two key features of hydra for trident are

  1. The defaults-list for hierarchical configuration composition.

  2. Package directives to cleanly combine configurations.

Below is a brief primer of hydra.

Context

For trident, we will use hydra to declare our

  • Trainer: declares how training, validation, and testing are run

  • TridentModule: declares your “model”

  • TridentDatamodule: declares your train, val, and test splits

  • And other components like logging and checkpointing

Hierarchical Configuration

hydra is a Python yaml framework for composing complex pipelines. In our simplified example, the directory tree of your trident configuration may look like the file tree below.

In what follows, we will focus on the hierarchy of the configuration below in the case of dataspec. A dataspec defines the dataset, the associated pipelines for preprocessing and evaluation, and the dataloader for the preprocessed dataset.

configs
├── config.yaml
├── dataspec                       # defines dataset, preprocessing, evaluation, and dataloader
│   ├── dataloader
│   │   ├── default.yaml
│   │   └── train.yaml
│   ├── evaluation
│   │   └── text_classification.yaml
│   ├── preprocessing
│   │   └── shots.yaml
│   ├── default.yaml               # general
│   ├── text_classification.yaml   # task-specific: inherits general
│   └── nli.yaml                   # more dataset-specific: inherits task-specific
├── dataspecs                      # groups dataspecs for a particular benchmark
└── module
    ├── default.yaml
    └── my_module.yaml

The configs/dataspec/default.yaml defines the defaults:

  1. The null default of dataset reserves the key for a future override

  2. The dataloader: default entry means that ./configs/dataspec/dataloader/default.yaml will be sourced into the dataloader key of the default dataspec

  3. We define the default _target_ of a dataset, which typically uses the 🤗 datasets library

defaults:
  # a null default reserves the key for future override
  - dataset: null
  # the default dataloader is sourced into the dataloader key of the dataspec
  # i.e. ./configs/dataspec/dataloader.yaml will be sourced in the `dataloader` key of the config
  - dataloader: default
  # _self_ allows you to control the resolution order of the config itself
  # _self_ is not required and appended to the end of the defaults list by default
  - _self_

dataset:
  # most datasets use the 🤗 datasets library
  _target_: datasets.load.load_dataset

The configs/dataspec/text_classification.yaml then extends the default.yaml

defaults:
  - default
  - evaluation: text_classification

# .. more configuration

The above sources the configuration of configs/dataspec/default.yaml and additionally sources the task-specific ./configs/dataspec/evaluation/text_classification.yaml into the evaluation key of text_classification.yaml.
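
For illustration, the composed text_classification.yaml resolves to roughly the following (a sketch; the contents of the sourced files are elided):

# schematic result of composing default.yaml and the defaults-list above
dataset:
  _target_: datasets.load.load_dataset
dataloader:
  # ... contents of configs/dataspec/dataloader/default.yaml
evaluation:
  # ... contents of configs/dataspec/evaluation/text_classification.yaml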

Lastly, ./configs/dataspec/nli.yaml defines even more specific configuration.

defaults:
  - text_classification

# ... nli-specific configuration

Notes:

  • hydra follows an additive configuration paradigm: design your configuration to incrementally add what’s required! Unsetting or removing options is often unwieldy

  • The configuration relies on relative paths in the configuration folder structure

  • In the defaults-list, the last one wins (dictionaries are merged sequentially)

  • The config’s own values typically come last (this can be controlled via _self_, see the official documentation)

Important Special Keywords and Values

Beyond the keywords below, ??? denotes a value in a default.yaml that must be set by an inheriting config.

Special Hydra Syntax

_target_ (str)
    Points to the Python function or class that initializes the object.

_recursive_ (bool)
    When set to false, sub-keys that comprise a _target_ are not instantiated eagerly, but only when hydra.utils.instantiate is called.

_partial_ (bool)
    Instantiates the _target_ function curried with the pre-set arguments and keywords.

_args_ (list[Any])
    Positional arguments for the _target_ at hydra.utils.instantiate.

_self_ (str)
    Only for defaults-lists: you can add _self_ to control the resolution order of the config itself. By default, _self_ is appended to the end.

Object and Function Instantiation

hydra allows you to define Python objects and functions declaratively in yaml. The example below instantiates an AutoModelForSequenceClassification with roberta-base.

sequence_classification_model:
    # path to the callable that instantiates the object
    _target_: transformers.AutoModelForSequenceClassification.from_pretrained
    # kwargs for the object we want to instantiate
    # can be themselves instantiated!
    pretrained_model_name_or_path: "roberta-base"
    num_labels: 3
    # optionally positional arguments
    # _args_:
    #  - 0
    #  - 1

Note

Carefully check whether the parent node of sequence_classification_model has _recursive_: true or not! If it is set to false, sequence_classification_model has to be instantiated manually, i.e., hydra.utils.instantiate(cfg.{...}.sequence_classification_model).
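
A minimal sketch of such a manual instantiation (the config path below is purely illustrative):

import hydra
from omegaconf import OmegaConf

# load a composed config; the path is hypothetical
cfg = OmegaConf.load("configs/module/my_module.yaml")
# instantiate only this sub-tree, leaving the rest of the config untouched
model = hydra.utils.instantiate(cfg.sequence_classification_model)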

hydra yaml configuration comprises various reserved keys:

  • _target_: e.g., transformers.AutoModel.from_pretrained points to the callable to instantiate

  • _recursive_: false ensures that objects are not instantiated eagerly, but only when instantiated explicitly. trident takes care of instantiating your objects at the right time to bypass hydra limitations

  • _partial_: true is commonly used to instantiate functions with pre-set arguments and keywords

Packaging: combine various configs

Paired with absolute paths, package directives allow you to seamlessly relocate configuration in the defaults-list.

Imagine you now group a series of dataspecs for a particular benchmark dataset in configs/dataspecs/xnli_val_test.yaml.

Note

In the below configuration, the leading / denotes an absolute config path!

defaults:
  - /dataspec/nli@validation_xnli_en
  - /dataspec/nli@test_xnli_en
  # we can easily add more languages

validation_xnli_en:
  dataset:
    path: xnli
    name: en
    split: validation
test_xnli_en:
  dataset:
    path: xnli
    name: en
    split: test

The above sources ./configs/dataspec/nli.yaml into the validation_xnli_en and test_xnli_en keys of the xnli_val_test.yaml group of dataspecs. We can then refine the individual configuration of each dataspec in the main configuration.

Later on, we can seamlessly include or exclude groups of dataspecs (i.e., benchmarks) like the above.
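
Schematically, the composed xnli_val_test.yaml therefore resolves to one full dataspec per key, with the dataset overrides merged in (a sketch; inherited contents elided):

validation_xnli_en:
  dataset:
    _target_: datasets.load.load_dataset  # inherited from dataspec/default.yaml via nli.yaml
    path: xnli
    name: en
    split: validation
  dataloader:
    # ... inherited from dataspec/dataloader/default.yaml
  evaluation:
    # ... inherited from dataspec/evaluation/text_classification.yaml
test_xnli_en:
  # ... analogous, with split: test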

Project Structure

An exemplary structure for a user project is shown below:

  • configs holds the entire hydra yaml configuration

  • src comprises required code, typically for processing and evaluation, as referred to in the config

# yaml configuration
your-project
├── configs
│   ├── config.yaml                  # inherits all `default.yaml`
│   ├── experiment                   # typical entry point, 2nd-level `config.yaml` for your experiment
│   │   ├── default.yaml
│   │   └── nli.yaml
│   ├── module
│   │   ├── optimizer                # torch.optim
│   │   │   ├── adam.yaml
│   │   │   └── adamw.yaml
│   │   ├── scheduler                # learning-rate scheduler
│   │   │   └── linear_warm_up.yaml
│   │   ├── default.yaml
│   │   └── text_classification.yaml
│   ├── datamodule
│   │   ├── default.yaml
│   │   └── mnli_train.yaml
│   ├── dataspec                     # defines [dataset, preprocessing, dataloader, evaluation]
│   │   ├── dataloader
│   │   │   └── default.yaml
│   │   ├── evaluation
│   │   │   └── text_classification.yaml
│   │   ├── default.yaml             # inherits dataloader/default.yaml
│   │   ├── text_classification.yaml # task-specific dataspec; inherits default.yaml and evaluation/text_classification.yaml
│   │   └── nli.yaml                 # dataset-group-specific dataspec; inherits text_classification.yaml
│   ├── dataspecs                    # defines dataset-specific groups of dataspecs
│   │   ├── mnli_train.yaml
│   │   ├── xnli_val_test.yaml
│   │   ├── amnli_val_test.yaml
│   │   └── indicxnli_val_test.yaml
│   ├── hydra
│   │   └── default.yaml
│   ├── logger
│   │   ├── csv.yaml
│   │   └── wandb.yaml
│   ├── callbacks                    # defines callbacks like lightning.pytorch.ModelCheckpoint
│   │   └── default.yaml
│   └── trainer                      # defines lightning.pytorch.Trainer
│       ├── debug.yaml
│       └── default.yaml
└── src                              # typical code folder structure
    └── tasks
        └── text_classification
            ├── evaluation.py
            └── processing.py

Components

TridentModule

TridentModule extends the LightningModule. The configuration defines all required components for a TridentModule:

  1. model: the _target_ of your model constructor, from which TridentModule.model is initialized

  2. optimizer: the optimizer for all TridentModule parameters

  3. scheduler: the learning-rate scheduler for the optimizer

The default.yaml sets up the AdamW optimizer and a linear learning-rate scheduler by default.

# _target_ is hydra-lingo to point to the object (class, function) to instantiate
_target_: trident.TridentModule
# _recursive_: true would mean all keyword arguments are /already/ instantiated
# when passed to `TridentModule`
_recursive_: false

defaults:
# interleaved with setup so instantiated later (recursive false)
- optimizer: adamw  # see configs/module/optimizer/adamw.yaml for the default
- scheduler: linear_warm_up  # see configs/module/scheduler/linear_warm_up.yaml for the default

# required to be defined by user
model: ???

A common pattern is that users create a configs/module/task.yaml that predefines shared model and evaluation logic for a particular task.

defaults:
  - default
  - evaluation: text_classification
model:
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  num_labels: ???
  pretrained_model_name_or_path: ???

  • The model constructor points to transformers.AutoModelForSequenceClassification.from_pretrained. The actual model and number of labels are defined in either the experiment configuration or on the CLI (cf. the mandatory ??? values above).

TridentDataspec

A TridentDataspec class encapsulates the configuration for data handling in a machine learning workflow. It manages various aspects of data processing including dataset instantiation, preprocessing, dataloading, and evaluation.

Configuration Keys

  • dataset: Specifies how the dataset should be instantiated.

  • dataloader: Defines the instantiation of the DataLoader.

  • preprocessing (optional): Details the methods or function calls for dataset preprocessing.

  • evaluation (optional): Outlines any post-processing steps and metrics for dataset evaluation.

  • misc (optional): Reserved for miscellaneous settings that do not fit under other keys.

dataset

The dataset key defines how to instantiate a Dataset that is compatible with a PyTorch DataLoader. Any preprocessing should be defined in the corresponding preprocessing configuration of the TridentDataspec.

mnli_train:
  dataset: # required, config on how to instantiate dataset
    _target_: datasets.load_dataset
    path: glue
    name: mnli
    split: train
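
Instantiating this dataset node is equivalent to the following call into the 🤗 datasets library (a minimal sketch):

import datasets

# what hydra.utils.instantiate effectively executes for the `dataset` node above
dataset = datasets.load_dataset(path="glue", name="mnli", split="train")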

preprocessing

The preprocessing key in the configuration details the steps for preparing the dataset. It includes two special keys, method and apply, each holding dictionaries for specific preprocessing actions.

  • method: Contains dictionaries of class methods along with their keyword arguments. These are typically methods of the dataset class.

  • apply: Comprises dictionaries of user-defined functions, along with their keyword arguments, to be applied to the dataset. Be mindful that functions of apply, unlike most other keys, typically do not take _partial_: true

The preprocessing functions take the Dataset as the first positional argument and are called in the order of the configuration. Note that "method" is a convenience keyword; the same behavior can be achieved by pointing the "_target_" of an "apply" function to the class method.

Example Configuration

preprocessing:
  method:
    map: # dataset.map of huggingface `datasets.arrow_dataset.Dataset`
      function:
        _target_: src.tasks.text_classification.processing.preprocess_fn
        _partial_: true
        column_names:
          text: premise
          text_pair: hypothesis
        tokenizer:
          _partial_: true
          _target_: transformers.tokenization_utils_base.PreTrainedTokenizerBase.__call__
          self:
              _target_: transformers.AutoTokenizer.from_pretrained
              pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
          padding: false
          truncation: true
          max_length: 128
    # unify output format of MNLI and XNLI
    set_format:
      columns:
        - "input_ids"
        - "attention_mask"
        - "label"
  apply:
    example_function:
      _target_: mod.package.example_function
      # ...
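
The function of map referenced above might look like the following sketch (preprocess_fn and its exact signature are illustrative assumptions):

def preprocess_fn(
    examples: dict,      # an example (or batch) handed over by `dataset.map`
    column_names: dict,  # e.g., {"text": "premise", "text_pair": "hypothesis"}
    tokenizer,           # the partially-applied tokenizer from the config
) -> dict:
    # tokenize the sequence pair; the returned columns (input_ids,
    # attention_mask, ...) are merged into the dataset by `map`
    return tokenizer(
        text=examples[column_names["text"]],
        text_pair=examples[column_names["text_pair"]],
    )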

dataloader

The DataLoader configuration (configs/dataspec/dataloader/default.yaml) is preset with reasonable defaults, accommodating typical use cases.

Example Configuration

_target_: torch.utils.data.dataloader.DataLoader
collate_fn:
  _target_: transformers.data.data_collator.DataCollatorWithPadding
  tokenizer:
    _target_: transformers.AutoTokenizer.from_pretrained
    pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
  max_length: ???
batch_size: 32
pin_memory: true
shuffle: false
num_workers: 4

evaluation

The logic of evaluation is defined in ./configs/dataspec/evaluation/text_classification.yaml. It is common to define evaluation per type of task.

The evaluation configuration segments into the fields prepare, step_outputs, and metrics.

See also

trident.utils.types.EvaluationDict

prepare

prepare defines functions called on the batch, the model outputs, or the collected step_outputs.

The TridentModule passes the below keywords to these functions to facilitate evaluation. Since the TridentModule extends the LightningModule, useful attributes like trainer and trainer.datamodule are available at runtime.

Example Configuration

prepare:
  # takes (trident_module: TridentModule, batch: dict, split: Split, dataset_name: str) -> dict
  batch: null
  # takes (trident_module: TridentModule, outputs: dict, batch: dict, split: Split, dataset_name: str) -> dict
  outputs:
    _partial_: true
    _target_: src.tasks.text_classification.evaluation.get_preds
  # takes (trident_module: TridentModule, step_outputs: dict, split: Split, dataset_name: str) -> dict
  step_outputs: null

where get_preds is defined as follows and merely adds the predictions to the model outputs:

def get_preds(outputs: dict, *args, **kwargs) -> dict:
    # derive hard predictions from the logits and add them to the outputs
    outputs["preds"] = outputs["logits"].argmax(dim=-1)
    return outputs

See also

trident.utils.enums.Split, trident.utils.types.PrepareDict

step_outputs

step_outputs defines which keys are collected, per step, from the batch or outputs dictionary into a flattened outputs dict per evaluation dataloader. The flattened dictionary then holds the corresponding key-value pairs as input to the prepare_step_outputs function, which ultimately serves as input to the metrics computed at the end of an evaluation loop.

Note

trident ensures that after each evaluation loop, lists of np.ndarrays or torch.Tensors are correctly stacked into a single array or tensor with appropriate dimensions.

Example Configuration

# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  # can be a str
  batch: labels
  # or a list[str]
  outputs:
    - "preds"
    - "logits"

See also

trident.utils.flatten_dict()

metrics

metrics denotes a dictionary for all evaluated metrics. For instance, a metric such as acc may contain:

  • metric: how to instantiate the metric; typically a partial function; must return a Callable.

  • compute_on: Either eval_step or epoch_end, with the latter being the default.

  • kwargs: A custom syntax to fetch the kwargs of metric from one of [trident_module, outputs, batch, cfg]; outputs refers to the model outputs when compute_on is set to eval_step and to step_outputs when compute_on is set to epoch_end.

In the NLI example:

  • The keyword preds for torchmetrics.functional.accuracy is sourced from outputs["preds"].

  • The keyword target for torchmetrics.functional.accuracy is sourced from outputs["labels"].

Example Configuration

metrics:
  # name of the metric used, e.g., for logging
  acc:
    # instructions to instantiate the metric, preferably a torchmetrics.Metric
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
    # either "eval_step" or "epoch_end", defaults to "epoch_end"
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"

Note

It is also possible to use metrics for actions other than logging evaluation results. The metric function should then return None.

In the below case, we construct a “metric” that stores predictions to disk instead. A metric is not logged to wandb if the function returns None.

from pathlib import Path

import pandas as pd
import torch

from trident import TridentModule


def store_predictions(
    trident_module: TridentModule,
    preds: torch.Tensor,
    dataset_name: str,
    directory: str,
    *args,
    **kwargs,
) -> None:  # returns None, so nothing is logged as a metric value
    trainer = trident_module.trainer
    epoch = trainer.current_epoch
    # write the predictions for this dataset and epoch to a csv
    p = pd.DataFrame(preds.cpu())
    p.columns = ["prediction"]
    path = Path(directory).joinpath(f"dataset={dataset_name}_epoch={epoch}.csv")
    p.to_csv(str(path), index_label="ids")

And the corresponding metric configuration:

metrics:
  store_predictions:
    metric:
      _partial_: true
      _target_: $PATH_TO_FUNCTION.store_predictions
      directory: ${hydra:runtime.output_dir}
    compute_on: "epoch_end"
    kwargs:
      trident_module: "trident_module"
      preds: "outputs.preds"
      dataset_name: "dataset_name"

TridentDataModule

The default configuration (configs/datamodule/default.yaml) for a TridentDataModule defines how training and evaluation datasets are instantiated. Each split is a dictionary of TridentDataspec.

_target_: trident.TridentDataModule
_recursive_: false

misc:
    # reserved key for general TridentDataModule configuration
train:
    # DictConfig of TridentDataspec
val:
    # DictConfig of TridentDataspec
test:
    # DictConfig of TridentDataspec
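
A filled-in instance might look as follows (schematic; the dataspec contents are elided):

_target_: trident.TridentDataModule
_recursive_: false
train:
  mnli_train:
    # ... a TridentDataspec config
val:
  validation_xnli_en:
    # ... a TridentDataspec config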

Config Composition

Note

Hierarchical config composition heavily relies on defaults lists.

The below file tree is a common structure for a hierarchical TridentDataModule configuration in our NLI example.

We will hierarchically

  1. Compose a general dataspec

  2. Compose a task-specific text classification dataspec

  3. Compose an NLI dataspec

  4. Compose a train, val, or test split via dataspecs

  5. Compose a datamodule

configs
├── config.yaml
├── datamodule
│   └── default.yaml
├── dataspec
│   ├── dataloader
│   │   └── default.yaml
│   ├── evaluation
│   │   └── text_classification.yaml
│   ├── default.yaml
│   ├── nli.yaml
│   └── text_classification.yaml
└── dataspecs
    ├── mnli_train.yaml
    ├── xnli_val_test.yaml
    └── amnli_val_test.yaml

Default

The general dataspec simply defines the default (./configs/dataspec/default.yaml) configuration.

defaults:
  - dataset: null
  # pull in the default dataloader
  - dataloader: default

dataset:
  # most datasets use the 🤗 datasets library
  _target_: datasets.load.load_dataset

Text Classification

defaults:
  - default
  - evaluation: text_classification # see TridentDataspec evaluation

# task specific preprocessing
preprocessing:
    ... # see TridentDataspec preprocessing

NLI

The configs/dataspec/nli.yaml simply extends the task-specific text_classification.yaml by specifying columns for the tokenizer in preprocessing.

defaults:
  - text_classification

preprocessing:
  method:
    map:
      function:
        # column_names denotes the input to the tokenizer during preprocessing
        column_names:
          text: premise
          text_pair: hypothesis

Dataspecs

We can now compose dataspecs, which group TridentDataspec configurations for entire datasets.

The configs/dataspecs/xnli_val_test.yaml leverages hydra package directives to place the nli configuration into the corresponding dataspec keys.

defaults:
  # package `nli` of configs/dataspec into @{...}
  - /dataspec@validation_xnli_en: nli
  - /dataspec@validation_xnli_es: nli
  # ... can extend this to the entire XNLI benchmark for val and test splits
validation_xnli_en:
  dataset:
    path: xnli
    name: en
    split: validation
validation_xnli_es:
  dataset:
    path: xnli
    name: es
    split: validation
# ... can extend this to the entire XNLI benchmark for val and test splits

NLI Datamodules

Datamodule Configurations

We can now use package directives to include the configuration of configs/dataspecs/xnli_val_test.yaml into the val and test keys of the TridentDataModule.

Warning

When using packaging, make sure to provide a list of dataspecs configurations to allow merging multiple datamodule configurations in the experiment configuration.

Important:

  • A single TridentDataspec in train of the TridentDataModule yields batches of dict[str, Any] at runtime

  • Multiple TridentDataspec in train of the TridentDataModule yield batches of dict[str, dict[str, Any]] for multi-dataset training at runtime

Example Configuration

We now package configs/dataspecs/xnli_val_test.yaml into a list configuration in datamodule.val of our experiment. We can thereby easily include or exclude various datasets for training, validation, or testing.

# variant A: training on a single dataset
defaults:
  - /dataspecs@datamodule.train: mnli_train
  - /dataspecs@datamodule.val:
    - xnli_val_test
    - amnli_val_test
    - indicxnli_val_test
  - /dataspecs@datamodule.test:
    - xnli_val_test
    - amnli_val_test
    - indicxnli_val_test
# variant B: training on multiple datasets
defaults:
  - /dataspecs@datamodule.train:
    - mnli_train
    - xnli_train
# ...

Experiment

The experiment configuration also segments into a general default.yaml and a task-specific nli.yaml.

The run key is, next to module, datamodule, and trainer, a special key reserved for user configuration. The configuration under this key also gets saved to your logger (e.g., wandb).

defaults:
  - override /trainer: default
  - override /callbacks: default
  - override /logger: wandb

# `run` namespace should hold your individual configuration
run:
  seed: 42
  task: ???

trainer:
  max_epochs: 10
  devices: 1
  precision: "16-mixed"
  deterministic: true
  inference_mode: false

# _log_vars infers the first training dataset
# for logging the batch size
_log_vars:
  # needed because hydra cannot index list in interpolation
  train_datasets: ${oc.dict.keys:datamodule.train}
  train_dataset: ${_log_vars.train_datasets[0]}
  train_batch_size: ${datamodule.train.${_log_vars.train_dataset}.dataloader.batch_size}

logger:
  wandb:
    name: "model=${module.model.pretrained_model_name_or_path}_epochs=${trainer.max_epochs}_bs=${_log_vars.train_batch_size}_lr=${module.optimizer.lr}_scheduler=${module.scheduler.num_warmup_steps}_seed=${run.seed}"
    tags:
      - "${module.model.pretrained_model_name_or_path}"
      - "bs=${_log_vars.train_batch_size}"
      - "lr=${module.optimizer.lr}"
      - "scheduler=${module.scheduler.num_warmup_steps}"
    project: ${run.task}
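
The task-specific nli.yaml then composes the general default.yaml with the NLI dataspecs and the text-classification module: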
# @package _global_
# The above line is important! It sets the namespace of the config

defaults:
  - default
  # We can now combine `dataspecs` for training, validation, and testing
  - /dataspecs@datamodule.train:
    - mnli_train
  - /dataspecs@datamodule.val:
    - xnli_val_test
    - indicxnli_val_test
    - amnli_val_test
  - override /module: text_classification

run:
  task: nli

module:
  model:
    pretrained_model_name_or_path: "xlm-roberta-base"
    num_labels: 3

Commandline Interface

hydra allows you to simply set configuration items on the commandline. See the hydra documentation for more information.

# change the learning rate
python -m trident.run experiment=nli module.optimizer.lr=0.0001
# set a different optimizer
python -m trident.run experiment=nli module.optimizer=adam
# no lr scheduler
python -m trident.run experiment=nli module.scheduler=null
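
hydra's multirun flag (-m) also enables sweeps over overrides, e.g. (a sketch):

# sweep over learning rates and seeds via hydra multirun
python -m trident.run -m experiment=nli module.optimizer.lr=1e-5,3e-5 run.seed=42,43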

Warning

The commandline interface only supports absolute config paths. Overriding defaults that use relative paths at runtime is, for instance, not possible.