TridentModule¶
Configuration¶
A TridentModule wraps a LightningModule to facilitate configuring training, validation, and testing from Hydra.
A TridentModule is commonly defined hierarchically:
/config/module/default.yaml: universal defaults
/config/module/$TASK.yaml: task-specific configuration
The default.yaml configuration for the TridentModule is typically defined as follows.
# default.yaml:
# _target_ is hydra-lingo to point to the object (class, function) to instantiate
_target_: trident.TridentModule
# _recursive_: true would mean all kwargs are /already/ instantiated
# when passed to `TridentModule`
_recursive_: false
defaults:
  # interleaved with setup so instantiated later (recursive false)
  - optimizer: adamw.yaml # see config/module/optimizer/adamw.yaml for default
  - scheduler: linear_warm_up # see config/module/scheduler/linear_warm_up.yaml for default

# required to be set by user later on
model: ???
A task-specific configuration is typically defined as follows (e.g., nli.yaml):
# nli.yaml:
defaults:
  - default

model:
  _target_: transformers.AutoModelForSequenceClassification.from_pretrained
  num_labels: 3
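Trident's runner normally performs the instantiation, but for illustration the composed module config follows standard Hydra semantics. A minimal sketch, assuming a primary config named config.yaml composes the files above under cfg.module and that the remaining required fields (e.g., the pretrained checkpoint) are provided elsewhere:

# Minimal sketch only: trident's runner normally performs this step.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Because default.yaml sets _recursive_: false, the optimizer, scheduler,
    # and model arrive at TridentModule as configs and are instantiated later.
    module = hydra.utils.instantiate(cfg.module)
    print(type(module))

if __name__ == "__main__":
    main()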
API¶
Methods¶
The methods below are the user-facing TridentModule methods. Since TridentModule sub-classes LightningModule, all methods, attributes, and hooks of LightningModule are also available.
Important: you should not override the following methods, since the TridentModule automatically runs evaluation per the TridentDataspec configuration:
validation_step
test_step
You may override
on_validation_epoch_end
on_test_epoch_end
but should make sure to also call the super() method, as shown in the sketch below.
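A minimal sketch of such an override; the subclass name and the extra behavior after the super() call are illustrative, not part of trident:

from trident import TridentModule

class MyModule(TridentModule):
    def on_validation_epoch_end(self) -> None:
        # Let TridentModule run its configured evaluation logic first.
        super().on_validation_epoch_end()
        # Illustrative user-defined behavior added on top.
        print("validation epoch finished")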
forward¶
- TridentModule.forward(batch)[source]
Plain forward pass of your model into which the batch is unpacked.
- Parameters:
batch (dict) – input to your model
- Returns:
container with attributes required for evaluation
- Return type:
ModelOutput
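In effect, the batch dictionary is unpacked as keyword arguments into the wrapped model. A rough sketch, not the exact implementation, assuming the instantiated model is stored as self.model:

def forward(self, batch: dict):
    # Keys such as input_ids, attention_mask, and labels are passed as kwargs;
    # a Hugging Face model then returns a ModelOutput with loss, logits, etc.
    return self.model(**batch)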
training_step¶
- TridentModule.training_step(batch, batch_idx)[source]
Comprises the training step of your model, which performs a forward pass.
- Notes:
If you want to extend training_step, add an on_train_batch_end hook via overrides (see PyTorch Lightning's on_train_batch_end and the sketch after this entry).
- Parameters:
batch (dict) – typically comprising input_ids, attention_mask, and position_ids
batch_idx (int) – variable used internally by PyTorch Lightning
- Returns:
model output that must have ‘loss’ as attr or key
- Return type:
Union[dict[str, Any], ModelOutput]
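A minimal sketch of the hook mentioned in the notes; trident's overrides mechanism is the documented way to attach it without subclassing, so the subclass and the extra logging here are purely illustrative:

from trident import TridentModule

class MyModule(TridentModule):
    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        # `outputs` is whatever training_step returned and must carry the loss
        # as an attribute or key.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs.loss
        self.log("train/loss_again", loss)  # illustrative extra logging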
log_metric¶
- EvalMixin.log_metric(split, metric_key, input, log_kwargs=None, dataset_name=None)[source]
Log a metric for a given split with optional transformation.
- Parameters:
split (Split) – The evaluation split.
metric_key (Union[str, DictKeyType]) – Key identifying the metric.
input (Union[None, int, float, dict, Tensor]) – Metric value or dictionary of metric values.
log_kwargs (Optional[dict[str, Any]]) – Additional keyword arguments for logging.
dataset_name (Optional[str]) – Name of the dataset if available.
- Notes:
This method assumes the existence of self.evaluation.metrics.
If input is a dictionary, each key-value pair is logged separately with the appropriate prefix.
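An illustrative call; it assumes this runs inside a TridentModule evaluation hook where a Split member and a computed accuracy are already in scope, and the metric name is made up for the example:

# Illustrative only: `split` is the Split member handed to the evaluation
# hook and `accuracy` is a scalar tensor computed beforehand.
self.log_metric(split=split, metric_key="acc", input=accuracy)
# With a dict input, each key-value pair would be logged separately with
# the appropriate prefix (see the notes above).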
num_training_steps¶
- property OptimizerMixin.num_training_steps[source]
Infers the number of training steps per device, accounting for gradient accumulation.
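As a rough, hedged approximation of what such a property computes (trident's actual accounting, e.g., for distributed training or a max_steps limit, may differ):

# Hedged approximation only; the real property reads these quantities from
# the attached Trainer and dataloaders rather than taking them as arguments.
def approximate_num_training_steps(
    batches_per_epoch: int, max_epochs: int, accumulate_grad_batches: int
) -> int:
    # One optimizer step is taken every `accumulate_grad_batches` batches.
    steps_per_epoch = batches_per_epoch // accumulate_grad_batches
    return steps_per_epoch * max_epochs

Warmup-based schedulers such as the default linear_warm_up configuration typically consume this value to set their total number of steps.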