TridentDataspec

Configuration

The TridentDataspec class encapsulates the configuration for data handling in a machine learning workflow. It manages dataset instantiation, preprocessing, dataloading, and evaluation.

Configuration Keys

  • dataset: Specifies how the dataset should be instantiated.

  • dataloader: Defines the instantiation of the DataLoader.

  • preprocessing (optional): Details the methods or function calls for dataset preprocessing.

  • evaluation (optional): Outlines any post-processing steps and metrics for dataset evaluation.

  • misc (optional): Reserved for miscellaneous settings that do not fit under other keys.

dataset

The dataset key instantiates a Dataset that is compatible with a PyTorch DataLoader. Any preprocessing should be defined in the corresponding preprocessing configuration of the TridentDataspec.

mnli_train:
  dataset: # required, config on how to instantiate dataset
    _target_: datasets.load_dataset
    path: glue
    name: mnli
    split: train
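
Assuming the block is resolved following Hydra's `_target_` convention, the dataset configuration above corresponds roughly to calling datasets.load_dataset(path="glue", name="mnli", split="train"). The sketch below imitates that resolution with the standard library only. `instantiate_sketch` is a hypothetical helper for illustration, not part of trident or Hydra, and `fractions.Fraction` merely stands in for datasets.load_dataset so that the snippet runs without downloading data.

```python
import importlib
from typing import Any


def instantiate_sketch(cfg: dict[str, Any]) -> Any:
    """Resolve `_target_` to a callable and call it with the remaining keys."""
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    module_path, _, attr = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    return target(**kwargs)


# Stand-in for the `dataset` block (assumption): with the real config the
# target would be `datasets.load_dataset` with keys `path`, `name`, `split`.
cfg = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
value = instantiate_sketch(cfg)
print(value)
```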

preprocessing

The preprocessing key in the configuration details the steps for preparing the dataset. It includes two special keys, method and apply, each holding dictionaries for specific preprocessing actions.

  • method: Contains dictionaries of class methods along with their keyword arguments. These are typically methods of the dataset class.

  • apply: Comprises dictionaries of user-defined functions, along with their keyword arguments, to be applied to the dataset. Be mindful that the functions under apply, unlike most other keys, typically do not take _partial_: true.

The preprocessing functions take the Dataset as the first positional argument. The functions are called in the order in which they appear in the configuration. Note that "method" is a convenience keyword; the same effect can be achieved by pointing the "_target_" of an "apply" function to the corresponding method of the dataset class.

Example Configuration

preprocessing:
  method:
    map: # dataset.map of huggingface `datasets.arrow_dataset.Dataset`
      function:
        _target_: src.tasks.text_classification.processing.preprocess_fn
        _partial_: true
        column_names:
          text: premise
          text_pair: hypothesis
        tokenizer:
          _partial_: true
          _target_: transformers.tokenization_utils_base.PreTrainedTokenizerBase.__call__
          self:
              _target_: transformers.AutoTokenizer.from_pretrained
              pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
          padding: false
          truncation: true
          max_length: 128
    # unify output format of MNLI and XNLI
    set_format:
      columns:
        - "input_ids"
        - "attention_mask"
        - "label"
  apply:
    example_function:
      _target_: mod.package.example_function
      # ...
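
The calling convention described above can be sketched as follows. This is a simplified illustration, not trident's implementation: `ToyDataset`, `add_length`, and `preprocess_sketch` are hypothetical names, the functions are passed directly instead of being resolved from `_target_`, and every step is assumed to return the processed dataset.

```python
from typing import Any, Callable


class ToyDataset:
    """Hypothetical stand-in for a dataset class with a `map` method."""

    def __init__(self, rows: list[dict[str, Any]]) -> None:
        self.rows = rows

    def map(self, function: Callable[[dict], dict]) -> "ToyDataset":
        return ToyDataset([function(dict(row)) for row in self.rows])


def add_length(dataset: ToyDataset, column: str) -> ToyDataset:
    """User-defined function: the dataset is the first positional argument."""
    return dataset.map(lambda row: {**row, "length": len(row[column])})


def preprocess_sketch(dataset: ToyDataset, cfg: dict) -> ToyDataset:
    # Walk the config in order: `method` entries call methods of the dataset,
    # `apply` entries call free functions with the dataset as first argument.
    for key, steps in cfg.items():
        for name, spec in steps.items():
            if key == "method":
                dataset = getattr(dataset, name)(**spec)
            elif key == "apply":
                dataset = spec["function"](dataset, **spec["kwargs"])
    return dataset


cfg = {
    "method": {"map": {"function": lambda row: {**row, "text": row["text"].lower()}}},
    "apply": {"add_length": {"function": add_length, "kwargs": {"column": "text"}}},
}
result = preprocess_sketch(ToyDataset([{"text": "Hello"}]), cfg)
print(result.rows)
```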

dataloader

The DataLoader configuration (configs/dataspec/dataloader/default.yaml) is preset with reasonable defaults, accommodating typical use cases.

Example Configuration

_target_: torch.utils.data.dataloader.DataLoader
collate_fn:
  _target_: transformers.data.data_collator.DataCollatorWithPadding
  tokenizer:
    _target_: transformers.AutoTokenizer.from_pretrained
    pretrained_model_name_or_path: ${module.model.pretrained_model_name_or_path}
  max_length: ???
batch_size: 32
pin_memory: true
shuffle: false
num_workers: 4
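
Because the preprocessing example sets padding: false, the examples in a batch have different lengths, and the collator pads them when the batch is assembled. The pure-Python sketch below illustrates that idea only. `pad_collate` is a hypothetical helper; the actual DataCollatorWithPadding returns tensors, takes the padding token and side from the tokenizer, and handles further keys.

```python
def pad_collate(
    features: list[dict[str, list[int]]], pad_token_id: int = 0
) -> dict[str, list[list[int]]]:
    """Pad `input_ids` and `attention_mask` to the longest sequence in the batch."""
    longest = max(len(f["input_ids"]) for f in features)
    batch: dict[str, list[list[int]]] = {"input_ids": [], "attention_mask": []}
    for f in features:
        n_pad = longest - len(f["input_ids"])
        # Right-pad token ids with the pad token and the mask with zeros.
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
    return batch


features = [
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1]},
    {"input_ids": [101, 102], "attention_mask": [1, 1]},
]
batch = pad_collate(features)
print(batch)
```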

evaluation

The logic of evaluation is defined in ./configs/dataspec/evaluation/text_classification.yaml. It is common to define evaluation per type of task.

The evaluation configuration is divided into the fields prepare, step_outputs, and metrics.

See also

trident.utils.types.EvaluationDict

prepare

prepare defines functions called on the batch, the model outputs, or the collected step_outputs.

The TridentModule passes the keyword arguments listed below to these functions to facilitate evaluation. Since the TridentModule extends the LightningModule, useful attributes like trainer and trainer.datamodule are available at runtime.

Example Configuration

prepare:
  # takes (trident_module: TridentModule, batch: dict, split: Split, dataset_name: str) -> dict
  batch: null
  # takes (trident_module: TridentModule, outputs: dict, batch: dict, split: Split, dataset_name: str) -> dict
  outputs:
    _partial_: true
    _target_: src.tasks.text_classification.evaluation.get_preds
  # takes (trident_module: TridentModule, step_outputs: dict, split: Split, dataset_name: str) -> dict
  step_outputs: null

where get_preds is defined as follows and merely adds the predictions, i.e. the argmax over the logits, to outputs:

def get_preds(outputs: dict, *args, **kwargs) -> dict:
    outputs["preds"] = outputs["logits"].argmax(dim=-1)
    return outputs
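
To illustrate why such a function accepts *args and **kwargs, the sketch below calls a pure-Python variant with all keywords listed in the signature comment of the configuration. `get_preds_sketch` is a hypothetical stand-in that works on nested lists instead of tensors, and the values passed for trident_module, split, and dataset_name are placeholders; that the arguments arrive as keywords is an assumption based on the description above.

```python
def get_preds_sketch(outputs: dict, *args, **kwargs) -> dict:
    # Pure-Python argmax over the last dimension of `logits`.
    outputs["preds"] = [
        max(range(len(row)), key=row.__getitem__) for row in outputs["logits"]
    ]
    return outputs


# Hypothetical call; unused keywords are absorbed by **kwargs.
outputs = get_preds_sketch(
    trident_module=None,  # placeholder for the TridentModule
    outputs={"logits": [[0.1, 2.0, -1.0], [3.0, 0.2, 0.5]]},
    batch={"labels": [1, 0]},
    split="val",  # placeholder for a Split member
    dataset_name="mnli_val",  # placeholder name
)
print(outputs["preds"])
```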

See also

trident.utils.enums.Split, trident.utils.types.PrepareDict

step_outputs

step_outputs defines which keys are collected from the batch or outputs dictionary, per step, into the flattened outputs dict per evaluation dataloader. The flattened dictionary then holds the corresponding key-value pairs as input to the prepare_step_outputs function, whose result ultimately serves as input to the metrics computed at the end of an evaluation loop.

Note

trident ensures that, after each evaluation loop, lists of np.ndarrays or torch.Tensors are correctly stacked into a single array with appropriate dimensions.

Example Configuration

# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  # can be a str
  batch: labels
  # or a list[str]
  outputs:
    - "preds"
    - "logits"
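
The collection and flattening described above can be sketched in pure Python. `collect_step` and `flatten` are hypothetical helpers operating on lists; trident itself stacks arrays and tensors, so this only mirrors the bookkeeping.

```python
from collections import defaultdict


def collect_step(store: dict, batch: dict, outputs: dict, cfg: dict) -> None:
    """Append the configured keys of one evaluation step to the store."""
    for source_name, source in (("batch", batch), ("outputs", outputs)):
        keys = cfg.get(source_name) or []
        if isinstance(keys, str):  # a single key may be given as a plain str
            keys = [keys]
        for key in keys:
            store[key].append(source[key])


def flatten(store: dict) -> dict:
    """Concatenate the per-step lists into one flat list per key."""
    return {key: [x for step in steps for x in step] for key, steps in store.items()}


cfg = {"batch": "labels", "outputs": ["preds"]}
store = defaultdict(list)
collect_step(store, {"labels": [1, 0]}, {"preds": [1, 1]}, cfg)
collect_step(store, {"labels": [2]}, {"preds": [2]}, cfg)
step_outputs = flatten(store)
print(step_outputs)
```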

See also

trident.utils.flatten_dict()

metrics

metrics denotes a dictionary for all evaluated metrics. For instance, a metric such as acc may contain:

  • metric: how to instantiate the metric; typically a partial function; must return a Callable.

  • compute_on: Either eval_step or epoch_end, with the latter being the default.

  • kwargs: A custom syntax to fetch the keyword arguments of metric from one of the following: [trident_module, outputs, batch, cfg]. Here, outputs refers to the model outputs when compute_on is set to eval_step and to step_outputs when compute_on is set to epoch_end.

In the NLI example:
  • The keyword preds for torchmetrics.functional.accuracy is sourced from outputs["preds"].

  • The keyword target for torchmetrics.functional.accuracy is sourced from outputs["labels"].

Example Configuration

metrics:
  # name of the metric, used e.g. for logging
  acc:
    # instructions to instantiate metric, preferably torchmetrics.Metric
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
    # either "eval_step" or "epoch_end", defaults to "epoch_end"
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"
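
One way to picture the kwargs syntax is as a dotted path that is looked up in the available sources before the metric is called. The sketch below is an illustration under that assumption, not trident's implementation: `resolve` and `accuracy_sketch` are hypothetical, and only dictionary lookups are handled.

```python
from functools import partial


def accuracy_sketch(preds: list[int], target: list[int]) -> float:
    """Fraction of matching predictions; stand-in for a real metric."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def resolve(path: str, sources: dict):
    """Follow a dotted path such as "outputs.preds" through nested dicts."""
    value = sources
    for part in path.split("."):
        value = value[part]
    return value


metric_cfg = {
    "metric": partial(accuracy_sketch),  # mirrors `_partial_: true`
    "kwargs": {"preds": "outputs.preds", "target": "outputs.labels"},
}
# With compute_on "epoch_end", `outputs` would hold the collected step_outputs.
sources = {"outputs": {"preds": [1, 1, 2, 0], "labels": [1, 0, 2, 0]}}
kwargs = {name: resolve(path, sources) for name, path in metric_cfg["kwargs"].items()}
acc = metric_cfg["metric"](**kwargs)
print(acc)
```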

Note

It is also possible to use metrics for actions other than logging evaluation results. The metric function should then return None.

In the below case, we construct a “metric” that logs predictions instead. A metric is not logged to wandb if the function returns None.

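from pathlib import Path

import pandas as pd
import torch
from trident import TridentModule  # assumed import path; adjust if it differs


def store_predictions(
    trident_module: TridentModule,
    preds: torch.Tensor,
    dataset_name: str,
    directory: str,
    *args,
    **kwargs,
) -> None:
    """Write the predictions of the current epoch to a CSV file; returns None."""
    trainer = trident_module.trainer
    epoch = trainer.current_epoch
    # Convert to a NumPy array on the CPU before building the DataFrame.
    p = pd.DataFrame(preds.cpu().numpy())
    p.columns = ["prediction"]
    path = Path(directory).joinpath(f"dataset={dataset_name}_epoch={epoch}.csv")
    p.to_csv(str(path), index_label="ids")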

The corresponding metric configuration:

metrics:
  store_predictions:
    metric:
      _partial_: true
      _target_: $PATH_TO_FUNCTION.store_predictions
      directory: ${hydra:runtime.output_dir}
    compute_on: "epoch_end"
    kwargs:
      trident_module: "trident_module"
      preds: "outputs.preds"
      dataset_name: "dataset_name"

API

Properties

cfg

The cfg: omegaconf.DictConfig holds the dataspec configuration.

dataset

The dataset as declared in cfg.dataset and preprocessed by cfg.preprocessing.

evaluation

The evaluation as declared in cfg.evaluation.

Methods

get_dataloader

TridentDataspec.get_dataloader()

Creates a DataLoader for the dataset.

Parameters:
  • signature_columns – Columns to be used in the dataloader. Defaults to None. If passed, removes unused columns if configured in misc.

Returns:

The DataLoader configured as per the specified settings.

Return type:

DataLoader