trident.core.mixins package

Submodules

trident.core.mixins.evaluation module

class trident.core.mixins.evaluation.EvalMixin

Bases: LightningModule

eval_step(split, batch, dataloader_idx)

Performs the model forward pass and the user-defined batch transformation in an evaluation step.

Parameters:
  • split (Split) – The evaluation split.

  • batch (dict) – The batch of the evaluation (i.e. ‘val’ or ‘test’) step.

  • dataloader_idx (Optional[int]) – The index of the current evaluation dataloader, or None if there is a single dataloader.

Return type:

None

Notes

  • This function is called in validation_step and test_step of the LightningModule.
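
The delegation described in the note above can be sketched as follows. This is a minimal, framework-free sketch under assumptions: Split is a hypothetical stand-in enum, EvalMixinSketch is a hypothetical name, and eval_step only records its arguments instead of running a forward pass; the real EvalMixin subclasses LightningModule.

```python
from enum import Enum
from typing import Optional


class Split(str, Enum):
    # Hypothetical stand-in for trident's Split type
    VAL = "val"
    TEST = "test"


class EvalMixinSketch:
    """Sketch: validation_step and test_step both delegate to eval_step."""

    def __init__(self) -> None:
        # Records (split, batch, dataloader_idx) for each eval_step call
        self.calls: list = []

    def eval_step(self, split: Split, batch: dict, dataloader_idx: Optional[int]) -> None:
        # The real method runs the forward pass and user transformations;
        # this sketch only records the call.
        self.calls.append((split, batch, dataloader_idx))

    def validation_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int] = None) -> None:
        self.eval_step(Split.VAL, batch, dataloader_idx)

    def test_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int] = None) -> None:
        self.eval_step(Split.TEST, batch, dataloader_idx)


module = EvalMixinSketch()
module.validation_step({"labels": [0, 1]}, batch_idx=0)  # single dataloader
module.test_step({"labels": [1]}, batch_idx=0, dataloader_idx=1)
```

Keeping one shared eval_step means the validation and test paths cannot drift apart; only the split label differs.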

hparams: AttributeDict
log: Callable

Mixin for base model to define the evaluation loop largely via Hydra.

See also LightningModule.

The evaluation mixin enables writing evaluation via YAML files. Here is an example for sequence classification, borrowed from configs/evaluation/classification.yaml.

# apply transformation functions
prepare:
  batch: null # on each step
  outputs:    # on each step
    _target_: src.utils.hydra.partial
    _partial_: src.evaluation.classification.get_preds
    # we link evaluation.prepare.outputs against get_preds:
    #   def get_preds(outputs):
    #       outputs.preds = outputs.logits.argmax(dim=-1)
    #       return outputs
  step_outputs: null  # on flattened outputs of what's collected from steps
# Which keys/attributes are supposed to be collected from `outputs` and `batch`
step_outputs:
  outputs: "preds" # can be a str
  batch: # or a list[str]
    - labels

# either `metrics` for both splits, or split-specific `val_metrics` and `test_metrics`
metrics:
  # name of the metric, used e.g. for logging
  accuracy:
    # instructions to instantiate metric, preferably torchmetrics.Metric
    metric:
      _target_: torchmetrics.Accuracy
    # either on_step: true or on_epoch: true
    on_step: true
    compute:
      preds: "outputs:preds"
      target: "batch:labels"
  f1:
    metric:
      _target_: torchmetrics.F1
    on_step: true
    compute:
      preds: "outputs:preds"
      target: "batch:labels"
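
The prepare, step_outputs and compute stages configured above can be illustrated with a framework-free sketch. Assumptions: plain Python lists stand in for tensors (so argmax is computed manually), and collect and resolve are hypothetical helpers, not trident functions; only the "outputs:preds" / "batch:labels" lookup convention is taken from the config above.

```python
from types import SimpleNamespace
from typing import Any, Union


def get_preds(outputs: SimpleNamespace) -> SimpleNamespace:
    # List-based analogue of outputs.logits.argmax(dim=-1)
    outputs.preds = [max(range(len(row)), key=row.__getitem__) for row in outputs.logits]
    return outputs


def collect(source: Any, keys: Union[str, list]) -> dict:
    # Hypothetical helper: step_outputs accepts a single key or a list of keys,
    # read from a dict or from object attributes
    if isinstance(keys, str):
        keys = [keys]
    return {k: source[k] if isinstance(source, dict) else getattr(source, k) for k in keys}


def resolve(spec: str, outputs: dict, batch: dict) -> Any:
    # Hypothetical helper: "outputs:preds" -> outputs["preds"], "batch:labels" -> batch["labels"]
    where, key = spec.split(":", 1)
    return {"outputs": outputs, "batch": batch}[where][key]


# One evaluation step
outputs = get_preds(SimpleNamespace(logits=[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]))
batch = {"labels": [1, 0, 0]}
collected_outputs = collect(outputs, "preds")
collected_batch = collect(batch, ["labels"])

compute_cfg = {"preds": "outputs:preds", "target": "batch:labels"}
metric_kwargs = {name: resolve(spec, collected_outputs, collected_batch) for name, spec in compute_cfg.items()}
accuracy = sum(p == t for p, t in zip(metric_kwargs["preds"], metric_kwargs["target"])) / len(metric_kwargs["target"])
```

Here the predictions are [1, 0, 1] against labels [1, 0, 0], so the hand-computed accuracy is 2/3; in the configured setup, metric_kwargs would instead be passed to the instantiated torchmetrics metric.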

log_metric(split, metric_key, input, log_kwargs=None, dataset_name=None)

Log a metric for a given split with optional transformation.

Parameters:
  • split (Split) – The evaluation split.

  • metric_key (Union[str, DictKeyType]) – Key identifying the metric.

  • input (Union[None, int, float, dict, Tensor]) – Metric value or dictionary of metric values.

  • log_kwargs (Optional[dict[str, Any]]) – Additional keyword arguments for logging.

  • dataset_name (Optional[str]) – Name of the dataset if available.

Notes

  • This method assumes the existence of self.evaluation.metrics.

  • If input is a dictionary, each key-value pair is logged separately with the appropriate prefix.
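
The dictionary handling described in the notes can be sketched as below. The key layout "<dataset>/<split>/<metric>[/<sub_key>]" and skipping None inputs are assumptions for illustration, not trident's confirmed naming; log stands in for LightningModule.log and log_metric_sketch is a hypothetical name.

```python
from typing import Callable, Optional, Union


def log_metric_sketch(
    log: Callable[..., None],
    split: str,
    metric_key: str,
    input: Union[None, int, float, dict],
    log_kwargs: Optional[dict] = None,
    dataset_name: Optional[str] = None,
) -> None:
    if input is None:
        return  # assumption: nothing to log
    log_kwargs = log_kwargs or {}
    # Hypothetical prefix layout; the real prefix format may differ
    prefix = f"{dataset_name}/{split}" if dataset_name else split
    if isinstance(input, dict):
        # Each key-value pair is logged separately under the shared prefix
        for sub_key, value in input.items():
            log(f"{prefix}/{metric_key}/{sub_key}", value, **log_kwargs)
    else:
        log(f"{prefix}/{metric_key}", input, **log_kwargs)


logged = []
record = lambda key, value, **kw: logged.append((key, value, kw))
log_metric_sketch(record, "val", "f1", {"macro": 0.5, "micro": 0.75},
                  log_kwargs={"prog_bar": True}, dataset_name="mnli")
log_metric_sketch(record, "test", "acc", 0.9)
```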

on_eval_epoch_end(split)

Compute and log metrics for all datasets at the epoch’s end.

Note: the epoch only ends when all datasets are processed.

This method determines whether multiple datasets exist for the evaluation split and logs the metrics for each dataset accordingly.

Parameters:

split (Split) – Evaluation split, i.e., “val” or “test”.

Return type:

None
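
A sketch of the per-dataset logging at epoch end, assuming step outputs are grouped by dataset name (None for a single unnamed dataset) and that a dataset prefix is added only when several datasets exist. on_eval_epoch_end_sketch, compute, log and the key layout are hypothetical.

```python
from typing import Callable


def on_eval_epoch_end_sketch(
    split: str,
    step_outputs: dict,
    metric_name: str,
    compute: Callable[[list], float],
    log: Callable[[str, float], None],
) -> None:
    # step_outputs maps dataset name (or None) to the list of collected step outputs
    multiple = len(step_outputs) > 1
    for dataset_name, outputs in step_outputs.items():
        value = compute(outputs)
        key = f"{dataset_name}/{split}/{metric_name}" if multiple else f"{split}/{metric_name}"
        log(key, value)


def mean_correct(outputs: list) -> float:
    return sum(o["correct"] for o in outputs) / len(outputs)


# Two datasets: keys are prefixed with the dataset name
logged: dict = {}
on_eval_epoch_end_sketch(
    "val",
    {"mnli": [{"correct": 1}, {"correct": 0}], "rte": [{"correct": 1}]},
    "accuracy",
    mean_correct,
    logged.__setitem__,
)

# A single unnamed dataset: no prefix
single: dict = {}
on_eval_epoch_end_sketch("test", {None: [{"correct": 1}]}, "accuracy", mean_correct, single.__setitem__)
```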

on_test_epoch_end()
on_validation_epoch_end()
prepare_metric_input(cfg, outputs, split, batch=None, dataset_name=None)

Collects user-defined attributes of outputs & batch to compute a metric.

In the example below, the evaluation (i.e., the call of accuracy) uses dot notation to pass the following to accuracy:

  1. preds from outputs as preds

  2. labels from outputs as target

Note

The variations in type of the underlying objects (dicts or classes with attributes) are handled at runtime.

The following variables are available:
  • trident_module

  • outputs

  • batch

  • cfg

  • dataset_name

Notes

  • trident_module gives access to the Trainer, which in turn holds the TridentDatamodule

  • batch is only relevant when the metric is called at each step

  • outputs either denotes the output of a step or the concatenated step outputs

Example

metrics:
  acc:
    metric:
      _partial_: true
      _target_: torchmetrics.functional.accuracy
      task: "multiclass"
      num_classes: 3
    compute_on: "epoch_end"
    kwargs:
      preds: "outputs.preds"
      target: "outputs.labels"

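Parameters:
  • cfg (Union[dict, DictConfig]) – Configuration dictionary for metric computation.

  • outputs (Union[dict, NamedTuple]) – Outputs of a step, or the concatenated step outputs.

  • split (Split) – The evaluation split.

  • batch (Union[dict, NamedTuple, None]) – Batch data; only relevant when the metric is computed at each step.

  • dataset_name (Optional[str]) – Name of the dataset if available.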

Returns:

Dictionary containing required inputs for metric computation.

Return type:

dict

Raises:

ValueError – If the required key is not found in the provided data.
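
The runtime handling of dicts versus attribute-style objects and the dot-notation lookup can be sketched as follows. This is a simplified stand-in, assuming the metric config stores its arguments under kwargs as in the example above; _get and prepare_metric_input_sketch are hypothetical names.

```python
from types import SimpleNamespace
from typing import Any, Optional


def _get(obj: Any, name: str) -> Any:
    # Handle dicts and attribute-style objects (e.g. NamedTuple) at runtime
    if isinstance(obj, dict):
        if name not in obj:
            raise ValueError(f"Key {name!r} not found")
        return obj[name]
    if not hasattr(obj, name):
        raise ValueError(f"Attribute {name!r} not found")
    return getattr(obj, name)


def prepare_metric_input_sketch(
    cfg: dict,
    outputs: Any,
    batch: Optional[Any] = None,
    trident_module: Any = None,
    dataset_name: Optional[str] = None,
) -> dict:
    # Variables a dot-notation path may start with (mirrors the list above)
    scope = {
        "trident_module": trident_module,
        "outputs": outputs,
        "batch": batch,
        "cfg": cfg,
        "dataset_name": dataset_name,
    }
    prepared = {}
    for arg_name, path in cfg["kwargs"].items():
        root, *attrs = path.split(".")
        value = _get(scope, root)
        for attr in attrs:
            value = _get(value, attr)
        prepared[arg_name] = value
    return prepared


cfg = {"compute_on": "epoch_end", "kwargs": {"preds": "outputs.preds", "target": "outputs.labels"}}
outputs = SimpleNamespace(preds=[2, 0, 1], labels=[2, 1, 1])
metric_input = prepare_metric_input_sketch(cfg, outputs)
```

A path into batch raises ValueError when batch is None, which matches the note that batch only matters for per-step metrics.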

test_step(batch, batch_idx, dataloader_idx=None)

Return type:

None

validation_step(batch, batch_idx, dataloader_idx=None)

Return type:

None

trident.core.mixins.optimizer module

class trident.core.mixins.optimizer.OptimizerMixin(*args, **kwargs)

Bases: LightningModule

Mixin for base model to define configuration of optimizer and scheduler.

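The OptimizerMixin provides functionality to instantiate the optimizer and learning-rate scheduler from the configuration (configure_optimizers, configure_scheduler) and to infer the number of training steps (num_training_steps).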

configure_optimizers()

Prepares optimizer and scheduler.

configure_scheduler(optimizer, scheduler_cfg)

Configures the LR scheduler for the optimizer.

The instantiation of the scheduler takes the optimizer as the first positional argument.

# hparams.scheduler: passed config
scheduler: LambdaLR = hydra.utils.instantiate(self.hparams.scheduler, optimizer)

Note that the following values are hard-coded for the time being:
  • interval: step

  • frequency: 1

Parameters:

optimizer (Optimizer) – The PyTorch optimizer to schedule.

Returns:

The scheduler configuration dictionary in PyTorch Lightning format.

Return type:

dict[str, Union[str, int, LambdaLR]]
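
A minimal sketch of the return value described above. To stay free of Hydra and PyTorch it uses hypothetical stand-ins: fake_instantiate mimics hydra.utils.instantiate by calling cfg["_target_"] (a class here, rather than a dotted import path as in Hydra) with the optimizer as the first positional argument, and ConstantScheduler replaces LambdaLR. The outer dict follows PyTorch Lightning's configure_optimizers format with an "lr_scheduler" entry.

```python
from typing import Any


class ConstantScheduler:
    # Hypothetical stand-in for torch.optim.lr_scheduler.LambdaLR
    def __init__(self, optimizer: Any, factor: float = 1.0) -> None:
        self.optimizer = optimizer
        self.factor = factor


def fake_instantiate(cfg: dict, *args: Any) -> Any:
    # Mimics hydra.utils.instantiate(cfg, optimizer): extra positional args go first
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return cfg["_target_"](*args, **kwargs)


def configure_scheduler_sketch(optimizer: Any, scheduler_cfg: dict) -> dict:
    scheduler = fake_instantiate(scheduler_cfg, optimizer)
    # interval and frequency are hard-coded, as noted above
    return {"scheduler": scheduler, "interval": "step", "frequency": 1}


optimizer = object()  # placeholder for a torch.optim.Optimizer
lr_scheduler = configure_scheduler_sketch(optimizer, {"_target_": ConstantScheduler, "factor": 0.5})
# Shape returned from LightningModule.configure_optimizers
optimizers_config = {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
```

With interval set to "step", Lightning calls scheduler.step() after every optimizer step rather than once per epoch.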

hparams: AttributeDict
property num_training_steps

Infers the number of training steps per device, accounting for gradient accumulation.
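
One way to infer optimizer steps per device is sketched below. The formula is an assumption, not taken from trident: num_training_steps_sketch is a hypothetical name, it ignores max_steps, limit_train_batches and drop_last, and it assumes the final partial accumulation window in an epoch still triggers an optimizer step.

```python
import math


def num_training_steps_sketch(
    num_samples: int,
    batch_size: int,
    num_devices: int,
    accumulate_grad_batches: int,
    max_epochs: int,
) -> int:
    # Batches each device sees per epoch (data split evenly across devices)
    batches_per_device = math.ceil(num_samples / (batch_size * num_devices))
    # One optimizer step per accumulate_grad_batches batches; a partial
    # window at the end of the epoch is counted as a step
    steps_per_epoch = math.ceil(batches_per_device / accumulate_grad_batches)
    return steps_per_epoch * max_epochs


steps = num_training_steps_sketch(10_000, batch_size=32, num_devices=2, accumulate_grad_batches=4, max_epochs=3)
```

With 10,000 samples, batch size 32, two devices, accumulation over 4 batches and 3 epochs, each device sees 157 batches per epoch, giving 40 optimizer steps per epoch and 120 in total.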

Module contents