TridentModule
Configuration
A TridentModule is a wrapper around LightningModule that facilitates configuring training, validation, and testing from Hydra.
A TridentModule is commonly defined hierarchically:
- /config/module/default.yaml: universal defaults
- /config/module/$TASK.yaml: task-specific configuration
The default.yaml configuration for the TridentModule is typically defined as follows.
# default.yaml:
# _target_ is hydra-lingo to point to the object (class, function) to instantiate
_target_: trident.TridentModule
# _recursive_: true would mean all kwargs are /already/ instantiated
# when passed to `TridentModule`
_recursive_: false
defaults:
  # interleaved with setup, so instantiated later (_recursive_: false)
  - optimizer: adamw # see config/module/optimizer/adamw.yaml for default
  - scheduler: linear_warm_up # see config/module/scheduler/linear_warm_up.yaml for default
# required to be set by the user later on
model: ???
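Because _recursive_ is false, Hydra hands nested configs (optimizer, scheduler, model) to the TridentModule uninstantiated, so the module can instantiate them later during setup. A minimal, self-contained sketch of this Hydra behavior; Wrapper is a hypothetical stand-in for TridentModule, not trident's actual internals:

from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

class Wrapper:
    """Hypothetical stand-in for TridentModule, for illustration only."""
    def __init__(self, model):
        # with _recursive_: false, `model` arrives as a DictConfig, not a module
        assert isinstance(model, DictConfig)
        self.model_cfg = model

cfg = OmegaConf.create({
    "_target_": "__main__.Wrapper",
    "_recursive_": False,
    "model": {"_target_": "torch.nn.Linear", "in_features": 4, "out_features": 2},
})
wrapper = instantiate(cfg)              # model config passed through uninstantiated
model = instantiate(wrapper.model_cfg)  # instantiated later, e.g. during setup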
A task-specific configuration is typically defined as follows (e.g., nli.yaml):
# nli.yaml:
defaults:
  - default
model:
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  num_labels: 3
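Via Hydra's defaults list, nli.yaml inherits everything from default.yaml and sets the required model key. The composed module config then looks roughly like the following sketch (the optimizer and scheduler bodies come from their respective files and are elided here):

_target_: trident.TridentModule
_recursive_: false
optimizer: ... # contents of config/module/optimizer/adamw.yaml
scheduler: ... # contents of config/module/scheduler/linear_warm_up.yaml
model:
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  num_labels: 3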
API
Methods
The methods below are user-facing TridentModule methods. Since TridentModule subclasses LightningModule, all methods, attributes, and hooks of LightningModule are also available.
Important: You should not override the following methods:
- validation_step
- test_step
since the TridentModule automatically runs evaluation per the TridentDataspec configuration.
You may override
- on_validation_epoch_end
- on_test_epoch_end
but should make sure to also call the super() method, as in the sketch below!
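For instance (a sketch; MyModule and the custom logic are hypothetical):

from trident import TridentModule

class MyModule(TridentModule):
    def on_validation_epoch_end(self) -> None:
        # TridentModule computes and logs the configured metrics here,
        # so the super() call must not be dropped
        super().on_validation_epoch_end()
        # ... custom end-of-epoch logic goes here ...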
forward
- TridentModule.forward(batch)
Plain forward pass of your model, into which the batch is unpacked.
- Parameters:
  batch (dict) – input to your model
- Returns:
  container with attributes required for evaluation
- Return type:
  ModelOutput
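A hypothetical call, assuming module is a TridentModule wrapping a transformers sequence-classification model and the tensors are prepared elsewhere:

batch = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
outputs = module(batch)  # batch dict is unpacked into the wrapped model's forward
loss = outputs.loss      # ModelOutput supports attribute access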
training_step
- TridentModule.training_step(batch, batch_idx)
Comprises the training step of your model, which performs a forward pass.
- Notes:
  If you want to extend training_step, add an on_train_batch_end method via overrides (see the sketch below). See: PyTorch Lightning's on_train_batch_end
- Parameters:
  - batch (dict) – typically comprising input_ids, attention_mask, and position_ids
  - batch_idx (int) – variable used internally by PyTorch Lightning
- Returns:
  model output that must have 'loss' as an attribute or key
- Return type:
  Union[dict[str, Any], ModelOutput]
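A sketch of the extension pattern from the note above (hook signature per recent PyTorch Lightning releases; MyModule is hypothetical):

from trident import TridentModule

class MyModule(TridentModule):
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # `outputs` is the return value of training_step, so the loss is available
        train_loss = outputs["loss"]
        # ... custom per-batch logic goes here ...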
log_metric
- EvalMixin.log_metric(split, metric_key, input, log_kwargs=None, dataset_name=None)
Log a metric for a given split with optional transformation.
- Parameters:
  - split (Split) – The evaluation split.
  - metric_key (Union[str, DictKeyType]) – Key identifying the metric.
  - input (Union[None, int, float, dict, Tensor]) – Metric value or dictionary of metric values.
  - log_kwargs (Optional[dict[str, Any]]) – Additional keyword arguments for logging.
  - dataset_name (Optional[str]) – Name of the dataset, if available.
- Notes:
  - This method assumes the existence of self.evaluation.metrics.
  - If input is a dictionary, each key-value pair is logged separately with the appropriate prefix.
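A hypothetical call, e.g. from a custom epoch-end hook (the metric name, values, and the origin of split are assumptions):

# each key-value pair is logged separately, prefixed per the notes above
self.log_metric(split=split, metric_key="f1", input={"micro": 0.82, "macro": 0.79})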
num_training_steps
- property OptimizerMixin.num_training_steps
Infers the number of training steps per device, accounting for gradient accumulation.
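Conceptually, the computation resembles the following (a sketch, not trident's exact implementation; the trainer attributes are PyTorch Lightning's):

import math

# with a distributed sampler, len(train_dataloader) is already per device;
# gradient accumulation folds several batches into one optimizer step
steps_per_epoch = math.ceil(len(train_dataloader) / trainer.accumulate_grad_batches)
num_training_steps = steps_per_epoch * trainer.max_epochs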