trident package

Submodules

trident.run module

trident.run.instantiate_objects(cfg, key)[source]
Return type:

List[Union[Callback, Logger]]
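A hedged usage sketch (assuming a Hydra-style config in which each entry under the given key carries a _target_; the "callbacks" key and the ModelCheckpoint target below are illustrative, not trident defaults):

from omegaconf import OmegaConf

from trident.run import instantiate_objects

# illustrative config; key name and callback target are assumptions
cfg = OmegaConf.create({
    "callbacks": {
        "checkpoint": {
            "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
            "monitor": "val/loss",
        }
    }
})
callbacks = instantiate_objects(cfg, "callbacks")  # -> [ModelCheckpoint(...)]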

trident.run.main(cfg)[source]
trident.run.run(cfg)[source]

Contains the training pipeline and instantiates all PyTorch Lightning objects from the Hydra config.

Parameters:

cfg (DictConfig) – Configuration composed by Hydra.

Returns:

Metric score for hyperparameter optimization.

Return type:

Optional[float]
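A minimal entrypoint sketch that delegates to run (config_path and config_name below are placeholders for your project layout, not trident defaults):

import hydra
from omegaconf import DictConfig

from trident.run import run

@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig):
    # `run` executes the training pipeline and optionally returns a
    # metric score for hyperparameter optimization
    return run(cfg)

if __name__ == "__main__":
    main()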

Module contents

class trident.TridentDataModule(train=None, val=None, test=None)[source]

Bases: LightningDataModule

get(split, default=None)[source]

Retrieve the TridentDataspecs for the given split.

This method attempts to fetch the dataspec associated with a specific split. If the split is not found, it returns the default value.

Parameters:
  • split (Split) – The Split used to retrieve the dataspec.

  • default (Optional[Any]) – The default value to return if the split is not found.

Return type:

Optional[DictList[TridentDataspec]]

Returns:

The DictList of TridentDataspec for the given split, or default if the split is not found.
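A short lookup sketch (whether get accepts a plain string split or requires trident's Split enum is an assumption here; datamodule stands for any configured TridentDataModule):

val_specs = datamodule.get("val", default=None)
if val_specs is None:
    # no validation dataspecs were configured for this run
    ...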

predict_dataloader()[source]
Return type:

Union[DataLoader, CombinedLoader]

setup(stage=None)[source]
Return type:

None

test_dataloader()[source]
Return type:

Union[DataLoader, CombinedLoader]

train_dataloader()[source]
Return type:

Union[DataLoader, CombinedLoader]

trainer: Trainer[source]

The base class for all datamodules.

The TridentDataModule facilitates writing a LightningDataModule with little to no boilerplate via Hydra configuration. Its configuration splits into two top-level keys:

  • dataset

  • dataloader

Parameters:

dataset (omegaconf.dictconfig.DictConfig) –

A hierarchical DictConfig that instantiates or returns the dataset for self.dataset_{train, val, test}, respectively.

Typical configurations follow the pattern below:

See also

src.utils.hydra.instantiate_and_apply(), src.utils.hydra.expand()

dataloader (omegaconf.dictconfig.DictConfig) –

A hierarchical DictConfig that configures the DataLoader for the train, val, and test splits, respectively.

See also

src.utils.hydra.expand()

Notes

  • The train, val, and test keys of dataset and dataloader are each merged with the remaining, shared configuration, with priority given to the split-specific config

  • dataloader automatically generates train, val, and test keys for convenience, as the config is evaluated lazily (i.e. when a DataLoader is requested)

Example

_target_: src.datamodules.base.TridentDataModule
_recursive_: false

dataset:
  _target_: datasets.load.load_dataset
  # access methods of the instantiated object
  _method_:
    map: # dataset.map for e.g. tokenization
      # kwargs for dataset.map
      function:
        _target_:
        _partial_: true
      num_proc: 12
  path: glue
  name: mnli
  train:
    split: "train"
  val:
    # inherits `path`, `name`, etc.
    split: "validation_mismatched+validation_matched"
  test:
    # set `path`, `name`, `lang` specifically, remainder inherited
    path: xtreme
    name: xnli
    lang: de
    split: "test"
dataloader:
  _target_: torch.utils.data.dataloader.DataLoader
  batch_size: 8
  num_workers: 0
  pin_memory: true
  # linked against global cfg
val_dataloader()[source]
Return type:

Union[DataLoader, CombinedLoader]
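Putting the example configuration above to work, a hedged sketch (the cfg.datamodule handle and the "fit" stage follow common Hydra/Lightning conventions and are assumptions, not documented trident API):

from hydra.utils import instantiate

# `cfg` is assumed to hold the YAML example above under the `datamodule` key;
# `_recursive_: false` in the config defers nested instantiation to trident
datamodule = instantiate(cfg.datamodule)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()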

class trident.TridentModule(model, optimizer, scheduler=None, initialize_model=True, *args, **kwargs)[source]

Bases: OptimizerMixin, EvalMixin

Base module of Trident that wraps model, optimizer, scheduler, and evaluation.

Parameters:
  • model (DictConfig) –

    Needs to instantiate a torch.nn.Module that

    • Takes the unpacked batch

    • Returns a container with “loss” and other required attributes

    See also

    src.modules.base.TridentModule.forward(), src.modules.base.TridentModule.training_step(), tiny bert example

  • optimizer (DictConfig) –

Configuration for the optimizer of your src.modules.base.TridentModule.

    See also

    src.modules.mixin.optimizer.OptimizerMixin, AdamW config

  • scheduler (Optional[DictConfig]) –

    Configuration for the scheduler of the optimizer of your src.modules.base.TridentModule.

    See also

    src.modules.mixin.optimizer.OptimizerMixin, Linear Warm-Up config

  • evaluation –

    Please refer to the evaluation documentation.

    See also

    src.modules.mixin.evaluation.EvalMixin, Classification Evaluation config
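A hedged instantiation sketch (the target paths and hyperparameters are illustrative; passing the optimizer as a _partial_ config follows common Hydra practice and is an assumption):

from omegaconf import OmegaConf

from trident import TridentModule

# illustrative model config in the spirit of the tiny bert example
model_cfg = OmegaConf.create({
    "_target_": "transformers.AutoModelForSequenceClassification.from_pretrained",
    "pretrained_model_name_or_path": "prajjwal1/bert-tiny",
    "num_labels": 3,
})
# illustrative AdamW config, left partial so the module can bind its parameters
optimizer_cfg = OmegaConf.create({
    "_target_": "torch.optim.AdamW",
    "_partial_": True,
    "lr": 2e-5,
})
module = TridentModule(model=model_cfg, optimizer=optimizer_cfg)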

forward(batch)[source]

Plain forward pass of your model, into which the batch is unpacked.

Parameters:

batch (dict) – input to your model

Returns:

container with attributes required for evaluation

Return type:

ModelOutput
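In spirit, the forward pass reduces to unpacking the batch into the wrapped model (a sketch, not necessarily the verbatim implementation):

def forward(self, batch: dict) -> "ModelOutput":
    # unpack the batch dict into the wrapped model's keyword arguments
    return self.model(**batch)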

training_step(batch, batch_idx)[source]

Comprises the training step of your model, which performs a forward pass.

Notes:

If you want to extend training_step, add an on_train_batch_end method via overrides. See PyTorch Lightning’s on_train_batch_end.

Parameters:
  • batch (dict) – typically comprising input_ids, attention_mask, and position_ids

  • batch_idx (int) – variable used internally by pytorch-lightning

Returns:

model output that must have ‘loss’ as an attribute or key

Return type:

Union[dict[str, Any], ModelOutput]
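
A sketch of the contract described above (the "train/loss" logging key is illustrative, not documented trident behavior):

def training_step(self, batch: dict, batch_idx: int):
    outputs = self(batch)  # forward pass; output must expose "loss"
    self.log("train/loss", outputs["loss"])
    return outputs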