trident.core.mixins package¶
Submodules¶
trident.core.mixins.evaluation module¶
- class trident.core.mixins.evaluation.EvalMixin[source]¶
Bases:
LightningModule
- eval_step(split, batch, dataloader_idx)[source]¶
Performs model forward & user batch transformation in an eval step.
- Parameters:
  - split (Split) – The evaluation split.
  - batch (dict) – The batch of the evaluation (i.e. ‘val’ or ‘test’) step.
  - dataloader_idx (int) – The index of the current evaluation dataloader, None if single dataloader.
- Return type:
None
Notes
This function is called in validation_step and test_step of the LightningModule.
See also
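For illustration, a minimal sketch of this delegation; the class name is hypothetical and the exact wiring inside trident may differ:

class EvalHooks:
    # Both Lightning hooks funnel into the shared eval_step with the matching split.
    def validation_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.eval_step("val", batch, dataloader_idx)   # "val" stands in for the corresponding Split value

    def test_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.eval_step("test", batch, dataloader_idx)  # "test" stands in for the corresponding Split value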
- log: Callable[source]¶
Mixin for base model to define evaluation loop largely via hydra.
See also LightningModule.
The evaluation mixin enables writing evaluation via YAML files; here is an example for sequence classification, borrowed from configs/evaluation/classification.yaml.
# apply transformation function
prepare:
  batch: null # on each step
  outputs:    # on each step
    _target_: src.utils.hydra.partial
    _partial_: src.evaluation.classification.get_preds
  step_outputs: null # on flattened outputs of what's collected from steps

# we link evaluation.apply.outputs against get_preds
def get_preds(outputs):
    outputs.preds = outputs.logits.argmax(dim=-1)
    return outputs

# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  outputs: "preds" # can be a str
  batch:           # or a list[str]
    - labels

# either metrics or val_metrics and test_metrics, where the latter
metrics:
  # name of the metric used e.g. for logging
  accuracy:
    # instructions to instantiate metric, preferably torchmetrics.Metric
    metric:
      _target_: torchmetrics.Accuracy
    # either on_step: true or on_epoch: true
    on_step: true
    compute:
      preds: "outputs:preds"
      target: "batch:labels"
  f1:
    metric:
      _target_: torchmetrics.F1
    on_step: true
    compute:
      preds: "outputs:preds"
      target: "batch:labels"
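For orientation, a hedged sketch of the data flow behind the accuracy entry above; the task and num_classes arguments are illustrative additions required by newer torchmetrics versions, and the actual wiring is done by the mixin via Hydra:

import torch
import torchmetrics

# Hedged sketch, not trident's implementation: instantiate the metric and feed it
# the values referenced by the `compute` mapping.
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)

outputs = {"preds": torch.tensor([0, 2, 1])}  # collected according to step_outputs.outputs
batch = {"labels": torch.tensor([0, 2, 2])}   # collected according to step_outputs.batch

# "outputs:preds" -> outputs["preds"], "batch:labels" -> batch["labels"]
acc = accuracy(preds=outputs["preds"], target=batch["labels"])  # on_step: true => called each step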
- log_metric(split, metric_key, input, log_kwargs=None, dataset_name=None)[source]¶
Log a metric for a given split with optional transformation.
- Parameters:
  - split (Split) – The evaluation split.
  - metric_key (Union[str, DictKeyType]) – Key identifying the metric.
  - input (Union[None, int, float, dict, Tensor]) – Metric value or dictionary of metric values.
  - log_kwargs (Optional[dict[str, Any]]) – Additional keyword arguments for logging.
  - dataset_name (Optional[str]) – Name of the dataset if available.
Notes
- This method assumes the existence of self.evaluation.metrics.
- If input is a dictionary, each key-value pair is logged separately with the appropriate prefix.
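A hedged usage sketch; the dataset name, metric values, and log_kwargs are illustrative, and split is the Split value passed into the surrounding hook:

self.log_metric(split, metric_key="accuracy", input=0.87, dataset_name="mnli")
# A dict input is logged per key-value pair with the appropriate prefix;
# log_kwargs are forwarded to the logging call (assumption: Lightning's self.log kwargs).
self.log_metric(split, metric_key="f1", input={"micro": 0.81, "macro": 0.78},
                log_kwargs={"prog_bar": True})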
- on_eval_epoch_end(split)[source]¶
Compute and log metrics for all datasets at the epoch’s end.
Note: the epoch only ends when all datasets are processed.
This method determines if multiple datasets exist for the evaluation split and appropriately logs the metrics for each.
- Parameters:
split (Split) – Evaluation split, i.e., “val” or “test”.
- Return type:
None
- prepare_metric_input(cfg, outputs, split, batch=None, dataset_name=None)[source]¶
Collects user-defined attributes of outputs & batch to compute a metric.
In the below example, the evaluation (i.e., the call of accuracy) extracts preds from outputs and passes it as preds, and labels from outputs and passes it as target, to accuracy via dot notation.
Note
The variations in types (dict or classes with attributes) of the underlying object are handled at runtime.
- The following variables are available:
trident_module
outputs
batch
cfg
dataset_name
Notes
- trident_module yields access to the Trainer, which in turn also holds the TridentDatamodule
- batch is only relevant when the metric is called at each step
- outputs either denotes the output of a step or the concatenated step outputs
Example
metrics:
  acc:
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
      task: "multiclass"
      num_classes: 3
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"
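A minimal, hypothetical sketch of how a "source.key" string such as "outputs.preds" can be resolved against the available variables; the helper name and exact logic are assumptions, not trident's implementation:

from typing import Any

def resolve(spec: str, variables: dict[str, Any]) -> Any:
    # e.g. "outputs.preds" -> variables["outputs"].preds (or variables["outputs"]["preds"])
    source, _, key = spec.partition(".")
    obj = variables[source]
    return obj[key] if isinstance(obj, dict) else getattr(obj, key)

# For the acc entry above, prepare_metric_input then yields roughly
# {"preds": outputs.preds, "target": outputs.labels}, which is passed to the metric.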
- Parameters:
  - cfg (Union[dict, DictConfig]) – Configuration dictionary for metric computation.
  - outputs (Union[dict, NamedTuple]) – Outputs data.
  - batch (Union[dict, NamedTuple, None]) – Batch data.
- Returns:
Dictionary containing required inputs for metric computation.
- Return type:
dict
- Raises:
ValueError – If the required key is not found in the provided data.
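A hedged usage sketch; metric_cfg, metric_fn, and the surrounding variables are illustrative names, not taken from the package:

kwargs = self.prepare_metric_input(
    cfg=metric_cfg["kwargs"],  # e.g. {"preds": "outputs.preds", "target": "outputs.labels"}
    outputs=outputs,
    split=split,
    batch=batch,
    dataset_name=dataset_name,
)
value = metric_fn(**kwargs)    # metric_fn: the metric instantiated from the config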
trident.core.mixins.optimizer module¶
- class trident.core.mixins.optimizer.OptimizerMixin(*args, **kwargs)[source]¶
Bases:
LightningModule
Mixin for base model to define configuration of optimizer and scheduler.
- The OptimizerMixin provides functionality to:
  - compute the number of training steps (OptimizerMixin.num_training_steps)
  - configure the optimizer(s) (OptimizerMixin.configure_optimizers)
  - configure the scheduler (OptimizerMixin.configure_scheduler)
Examples
Optimizer: AdamW
Scheduler: Linear Warm-Up
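Since the example configs themselves are not reproduced here, a hedged Python sketch of what such an AdamW + linear warm-up pairing boils down to; all names and values are illustrative:

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 2)  # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

num_warmup_steps, num_training_steps = 1_000, 10_000  # cf. OptimizerMixin.num_training_steps

def lr_lambda(step: int) -> float:
    # Ramp up linearly during warm-up, then decay linearly to zero.
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

scheduler = LambdaLR(optimizer, lr_lambda)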
- configure_scheduler(optimizer, scheduler_cfg)[source]¶
Configures the LR scheduler for the optimizer.
The instantiation of the scheduler takes the optimizer as the first positional argument.
# hparams.scheduler: passed config
scheduler: LambdaLR = hydra.utils.instantiate(self.hparams.scheduler, optimizer)
- Note that the below values are hard-coded for the time being:
interval: step
frequency: 1
- Parameters:
optimizer (Optimizer) – PyTorch optimizer
- Returns:
scheduler in pytorch-lightning format
- Return type:
dict[str, Union[str, int, LambdaLR]]
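A hedged sketch of the returned dictionary, using the hard-coded interval and frequency noted above (scheduler being the instantiated LambdaLR):

{
    "scheduler": scheduler,
    "interval": "step",
    "frequency": 1,
}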