TridentModule

Configuration

A TridentModule is a wrapper around the LightningModule that facilitates configuring training, validation, and testing from hydra.

A TridentModule is commonly defined hierarchically:

  1. /config/module/default.yaml: universal defaults

  2. /config/module/$TASK.yaml: task-specific configuration

The default.yaml configuration for the TridentModule is typically defined as follows.

# default.yaml:
# _target_ is hydra-lingo to point to the object (class, function) to instantiate
_target_: trident.TridentModule
# _recursive_: true would mean all kwargs are /already/ instantiated
# when passed to `TridentModule`
_recursive_: false

defaults:
# interleaved with setup so instantiated later (recursive false)
- optimizer: adamw.yaml  # see config/module/optimizer/adamw.yaml for default
- scheduler: linear_warm_up  # see config/module/scheduler/linear_warm_up.yaml for default

# required to be set by user later on
model: ???

A task-specific configuration typically is defined as follows (e.g., nli.yaml):

# nli.yaml:
defaults:
- default

model:
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  # the checkpoint is left for the user to set, mirroring the ??? convention above
  pretrained_model_name_or_path: ???
  num_labels: 3
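
Hydra then composes and instantiates the module from this configuration. The following is a minimal sketch of that wiring, not trident's actual entry point; the config path and config name are assumptions:

# minimal sketch (not trident's entry point); config_path and config_name are assumptions
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # With _recursive_: false, nested keys such as `model`, `optimizer`, and
    # `scheduler` reach the TridentModule as DictConfigs and are instantiated
    # later, interleaved with the module's own setup.
    module = instantiate(cfg.module)
    print(type(module))


if __name__ == "__main__":
    main()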

API

Methods

The methods below are user-facing TridentModule methods. Since TridentModule subclasses the LightningModule, all methods, attributes, and hooks of the LightningModule are also available.

Important: You should not override the following methods:

  • validation_step

  • test_step

since the TridentModule automatically runs evaluation per the TridentDataspec configuration.

You may override

  • on_validation_epoch_end

  • on_test_epoch_end

but you should make sure to also call the super() method, as in the sketch below.
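
A hypothetical sketch of such an override, which preserves trident's evaluation logic by delegating to the parent hook first (the logged key is an assumption for illustration):

from trident import TridentModule


class MyModule(TridentModule):
    def on_validation_epoch_end(self) -> None:
        # Let the TridentModule run its TridentDataspec-driven evaluation first.
        super().on_validation_epoch_end()
        # Then add any custom end-of-epoch logic, e.g. logging a diagnostic.
        self.log("val/custom_diagnostic", 0.0)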

forward

TridentModule.forward(batch)[source]

Plain forward pass of your model, into which the batch is unpacked.

Parameters:

batch (dict) – input to your model

Returns:

container with attributes required for evaluation

Return type:

ModelOutput
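
Illustratively, the documented behavior corresponds to a sketch along these lines (not the verbatim implementation):

def forward(self, batch: dict) -> "ModelOutput":
    # The batch dict is unpacked into the wrapped model's keyword arguments,
    # e.g. input_ids, attention_mask, labels.
    return self.model(**batch)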

training_step

TridentModule.training_step(batch, batch_idx)[source]

Comprises the training step of your model, which takes a forward pass.

Notes:

If you want to extend training_step, add an on_train_batch_end method via overrides. See: PyTorch Lightning’s on_train_batch_end

Parameters:
  • batch (dict) – typically comprising input_ids, attention_mask, and position_ids

  • batch_idx (int) – variable used internally by pytorch-lightning

Returns:

model output that must have ‘loss’ as an attribute or key

Return type:

Union[dict[str, Any], ModelOutput]
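
The contract above can be sketched as follows; the logging call and its key are assumptions added for illustration, not necessarily part of trident's implementation:

def training_step(self, batch: dict, batch_idx: int):
    # Forward pass via TridentModule.forward; the returned object must expose
    # `loss` as an attribute or key so Lightning can backpropagate it.
    outputs = self(batch)
    self.log("train/loss", outputs["loss"])  # illustrative logging key
    return outputs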

log_metric

EvalMixin.log_metric(split, metric_key, input, log_kwargs=None, dataset_name=None)[source]

Log a metric for a given split with optional transformation.

Parameters:
  • split (Split) – The evaluation split.

  • metric_key (Union[str, DictKeyType]) – Key identifying the metric.

  • input (Union[None, int, float, dict, Tensor]) – Metric value or dictionary of metric values.

  • log_kwargs (Optional[dict[str, Any]]) – Additional keyword arguments for logging.

  • dataset_name (Optional[str]) – Name of the dataset if available.

Notes:

  • This method assumes the existence of self.evaluation.metrics.

  • If input is a dictionary, each key-value pair is logged separately with the appropriate prefix.
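
A hypothetical usage from within an evaluation hook; the Split members, metric names, and values are assumptions for illustration:

# scalar metric, logged with a prefix derived from the split and, if given, the dataset name
self.log_metric(split=Split.val, metric_key="accuracy", input=0.87, dataset_name="mnli")

# dict input: each key-value pair is logged separately under the same prefix
self.log_metric(split=Split.test, metric_key="qa", input={"exact_match": 81.2, "f1": 88.6})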

num_training_steps

property OptimizerMixin.num_training_steps[source]

Infers the number of training steps per device, accounting for gradient accumulation.
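
As a rough illustration of the quantity being inferred (the attribute names are drawn from PyTorch Lightning's Trainer; the exact derivation inside trident may differ):

# per-device batches in one epoch
batches_per_epoch = len(train_dataloader)
# optimizer steps per epoch after gradient accumulation
steps_per_epoch = batches_per_epoch // trainer.accumulate_grad_batches
# total optimizer steps over training
num_training_steps = steps_per_epoch * trainer.max_epochs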