trident package
Subpackages
- trident.core package
- trident.utils package
Submodules
trident.run module
Module contents
- class trident.TridentDataModule(train=None, val=None, test=None)
- Bases: LightningDataModule
- get(split, default=None)
- Retrieve the TridentDataspecs for the given split.
- This method attempts to fetch the dataspecs associated with a specific split. If the split is not found, it returns the default value.
- Parameters:
- split (Split) – The Split used to retrieve the dataspecs.
- default (Optional[Any]) – The value to return if the split is not found.
 
- Return type:
- Optional[DictList[TridentDataspec]]
- Returns:
- The DictList of TridentDataspec for the given split, or default (None unless specified) if the split is not found.
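The lookup described above behaves much like dict.get. The following is a minimal sketch of that behaviour only: ToyDataModule is a hypothetical stand-in (not trident's class), and plain strings and dicts take the place of Split, DictList, and TridentDataspec.

```python
from typing import Any, Optional


class ToyDataModule:
    """Hypothetical stand-in; NOT trident's TridentDataModule."""

    def __init__(self, train=None, val=None, test=None):
        # Keep only the splits that were actually configured.
        self._dataspecs = {
            name: spec
            for name, spec in (("train", train), ("val", val), ("test", test))
            if spec is not None
        }

    def get(self, split: str, default: Optional[Any] = None) -> Optional[Any]:
        # Return the dataspecs for `split`, or `default` if the split is absent.
        return self._dataspecs.get(split, default)


dm = ToyDataModule(train={"mnli": "train dataspec"})
train_specs = dm.get("train")               # configured split
missing_specs = dm.get("test")              # absent split, no default given
fallback_specs = dm.get("test", default={}) # absent split, explicit default
```

Passing an explicit default lets calling code iterate over the result without a separate None check.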
 
- trainer: Trainer
- The base class for all datamodules.
- The TridentDataModule facilitates writing a LightningDataModule with little to no boilerplate via Hydra configuration. Its configuration splits into
- dataset:
- dataloader:
- Parameters:
- dataset (omegaconf.dictconfig.DictConfig) – A hierarchical DictConfig that instantiates or returns the dataset for self.dataset_{train, val, test}, respectively. Typical configurations follow the pattern shown in the Example below.
- See also src.utils.hydra.instantiate_and_apply(), src.utils.hydra.expand()
- dataloader (omegaconf.dictconfig.DictConfig) – See also src.utils.hydra.expand()
- Notes
- The train, val, and test keys of dataset and dataloader inherit the remaining (shared) configuration; keys already set on a split take priority over the shared ones.
- dataloader automatically generates train, val, and test keys for convenience, as the config is evaluated lazily (i.e. when a DataLoader is requested).
- Example

    _target_: src.datamodules.base.TridentDataModule
    _recursive_: false
    dataset:
      _target_: datasets.load.load_dataset
      # access methods of the instantiated object
      _method_:
        map: # dataset.map for e.g. tokenization
          # kwargs for dataset.map
          function:
            _target_:
            _partial_: true
          num_proc: 12
      path: glue
      name: mnli
      train:
        split: "train"
      val:
        # inherits `path`, `name`, etc.
        split: "validation_mismatched+validation_matched"
      test:
        # set `path`, `name`, `lang` specifically, remainder inherited
        path: xtreme
        name: xnli
        lang: de
        split: "test"
    dataloader:
      _target_: torch.utils.data.dataloader.DataLoader
      batch_size: 8
      num_workers: 0
      pin_memory: true # linked against global cfg
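The inheritance rule from the Notes can be illustrated in plain Python. This is a sketch under assumptions, not trident's expand(): the helper expand_splits is hypothetical, it works on ordinary dicts rather than a DictConfig, and it performs a shallow merge in which split-level keys override shared ones.

```python
from copy import deepcopy

SPLITS = ("train", "val", "test")


def expand_splits(cfg: dict) -> dict:
    """Hypothetical helper: copy shared keys into every split.

    Illustration only; the real expand() may merge differently.
    """
    shared = {k: v for k, v in cfg.items() if k not in SPLITS}
    expanded = {}
    for split in SPLITS:
        merged = deepcopy(shared)
        # Keys set on the split itself override the shared ones.
        merged.update(deepcopy(cfg.get(split) or {}))
        expanded[split] = merged
    return expanded


dataset_cfg = {
    "path": "glue",
    "name": "mnli",
    "train": {"split": "train"},
    "test": {"path": "xtreme", "name": "xnli", "lang": "de", "split": "test"},
}
expanded = expand_splits(dataset_cfg)
```

In this sketch the train split inherits path and name, the test split keeps its own path and name, and a val entry is generated from the shared keys even though none was configured.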
 
- class trident.TridentModule(model, optimizer, scheduler=None, initialize_model=True, *args, **kwargs)
- Bases: OptimizerMixin, EvalMixin
- Base module of Trident that wraps model, optimizer, scheduler, and evaluation.
- Parameters:
- model (DictConfig) – Needs to instantiate a torch.nn.Module that
- takes the batch unpacked
- returns a container with “loss” and other required attrs
- See also src.modules.base.TridentModule.forward(), src.modules.base.TridentModule.training_step(), tiny bert example
- optimizer (DictConfig) – Configuration for the optimizer of your core.trident.TridentModule. See also src.modules.mixin.optimizer.OptimizerMixin, AdamW config
- scheduler (Optional[DictConfig]) – Configuration for the scheduler of the optimizer of your src.modules.base.TridentModule. See also src.modules.mixin.optimizer.OptimizerMixin, Linear Warm-Up config
- evaluation – Please refer to evaluation. See also src.modules.mixin.evaluation.EvalMixin, Classification Evaluation config
 
- forward(batch)
- Plain forward pass of your model, for which the batch is unpacked.
- Parameters:
- batch (dict) – input to your model
- Returns:
- container with attributes required for evaluation 
- Return type:
- ModelOutput 
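A small sketch of what a forward pass "for which the batch is unpacked" could look like. It assumes that unpacking means keyword expansion of the batch dict (model(**batch)); toy_model and ToyOutput are hypothetical stand-ins for a torch.nn.Module and a ModelOutput, used so the example runs without any dependencies.

```python
from dataclasses import dataclass


@dataclass
class ToyOutput:
    """Hypothetical output container exposing `loss` as an attribute."""

    loss: float
    logits: list


def toy_model(input_ids, attention_mask):
    # Stand-in for a torch.nn.Module: receives the batch keys as keyword arguments.
    logits = [i * m for i, m in zip(input_ids, attention_mask)]
    return ToyOutput(loss=float(sum(logits)), logits=logits)


def forward(batch: dict) -> ToyOutput:
    # Assumption: "unpacked" means keyword expansion of the batch dict.
    return toy_model(**batch)


output = forward({"input_ids": [1, 2, 3], "attention_mask": [1, 1, 0]})
```

Under this assumption, the keys of the batch must match the parameter names of the model's forward signature.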
 
- training_step(batch, batch_idx)
- Comprises the training step of your model, which takes a forward pass.
- Notes:
- If you want to extend training_step, add an on_train_batch_end method via overrides. See: PyTorch Lightning’s on_train_batch_end
- Parameters:
- batch (dict) – typically comprising input_ids, attention_mask, and position_ids
- batch_idx (int) – variable used internally by pytorch-lightning
 
- Returns:
- model output that must have ‘loss’ as attr or key 
- Return type:
- Union[dict[str, Any], ModelOutput]
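The requirement that the output carries 'loss' as an attribute or a key can be checked with a helper along these lines. The function get_loss is hypothetical and not part of trident; it only shows how both return types in the signature above can be handled uniformly.

```python
from types import SimpleNamespace
from typing import Any


def get_loss(output: Any) -> Any:
    """Hypothetical helper: read `loss` from a mapping key or an attribute."""
    if isinstance(output, dict):
        # Raises KeyError if the model output lacks a "loss" entry.
        return output["loss"]
    # Raises AttributeError if the container lacks a `loss` attribute.
    return getattr(output, "loss")


loss_from_dict = get_loss({"loss": 0.25, "logits": [0.1, 0.9]})
loss_from_attr = get_loss(SimpleNamespace(loss=0.5))
```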