trident.core package¶
Subpackages¶
Submodules¶
trident.core.datamodule module¶
- class trident.core.datamodule.TridentDataModule(train=None, val=None, test=None)[source]¶
Bases:
LightningDataModule
- get(split, default=None)[source]¶
Retrieve the TridentDataspecs for the given split.
This method attempts to fetch a dataspec associated with a specific split. If the split is not found, it returns a default value.
- Parameters:
  - split (Split) – The Split used to retrieve the dataspec.
  - default (Optional[Any]) – The default value to return if the split is not found.
- Return type:
  Optional[DictList[TridentDataspec]]
- Returns:
  The DictList of TridentDataspec for the given split or None.
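A minimal usage sketch (not from the library docs): `datamodule` is assumed to be an already configured TridentDataModule, and the split keys are shown as plain strings although the actual Split type may be stricter.

```python
# Minimal usage sketch; `datamodule` is a hypothetical, already configured
# TridentDataModule instance, and string split keys are assumed to be accepted.
train_specs = datamodule.get("train")            # DictList[TridentDataspec] or None
val_specs = datamodule.get("val", default=None)  # explicit fallback if the split is absent
```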
- trainer: Trainer[source]¶

The base class for all datamodules.

The TridentDataModule facilitates writing a LightningDataModule with little to no boilerplate via Hydra configuration. It splits into dataset and dataloader:
- Parameters:
  - dataset (omegaconf.dictconfig.DictConfig) – A hierarchical DictConfig that instantiates or returns the dataset for self.dataset_{train, val, test}, respectively. Typical configurations follow the below pattern.
    See also src.utils.hydra.instantiate_and_apply(), src.utils.hydra.expand()
  - dataloader (omegaconf.dictconfig.DictConfig) –
    See also src.utils.hydra.expand()
Notes
- The train, val, and test keys of dataset and dataloader join the remaining configuration, with priority to the existing config.
- dataloader automatically generates train, val, and test keys for convenience, as the config is evaluated lazily (i.e. when a DataLoader is requested).
Example
```yaml
_target_: src.datamodules.base.TridentDataModule
_recursive_: false
dataset:
  _target_: datasets.load.load_dataset
  # access methods of the instantiated object
  _method_:
    map:  # dataset.map for e.g. tokenization
      # kwargs for dataset.map
      function:
        _target_:
        _partial_: true
      num_proc: 12
  path: glue
  name: mnli
  train:
    split: "train"
  val:
    # inherits `path`, `name`, etc.
    split: "validation_mismatched+validation_matched"
  test:
    # set `path`, `name`, `lang` specifically, remainder inherited
    path: xtreme
    name: xnli
    lang: de
    split: "test"
dataloader:
  _target_: torch.utils.data.dataloader.DataLoader
  batch_size: 8
  num_workers: 0
  pin_memory: true  # linked against global cfg
```
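A minimal sketch, not part of the documented API: how a YAML config like the example above could be turned into a datamodule via Hydra. The config path and the hook calls are illustrative, and the config keys may differ across trident versions.

```python
# Sketch only: instantiate the datamodule from a config like the example above.
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical location of the example config shown above.
cfg = OmegaConf.load("configs/datamodule/mnli.yaml")

# With `_recursive_: false`, nested nodes (dataset, dataloader) remain DictConfigs,
# so the datamodule can expand and instantiate them lazily per split.
datamodule = instantiate(cfg)

# Standard LightningDataModule hooks.
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
```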
trident.core.dataspec module¶
- class trident.core.dataspec.TridentDataspec(cfg, name='None')[source]¶
Bases:
object
A class to handle data specification in trident.
This class is designed to instantiate the dataset, preprocess the dataset, and create data loaders for training, validation, or testing.
The preprocessing configuration includes two special keys:
- "method": Holds dictionaries of class methods and their keyword arguments for preprocessing.
- "apply": Contains dictionaries for user-defined functions and their keyword arguments to apply on the dataset.
- cfg[source]¶
The configuration object that contains all the settings for dataset instantiation, preprocessing, dataloader setup, and evaluation metrics.
- Type:
DictConfig
- get_dataloader()[source]¶
Creates a DataLoader for the dataset.
- Parameters:
  signature_columns – Columns to be used in the dataloader. Defaults to None. If passed, removes unused columns if configured in misc.
- Returns:
The DataLoader configured as per the specified settings.
- Return type:
DataLoader
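A minimal sketch of constructing a dataspec and requesting its loader. The config keys below (dataset, dataloader) are assumptions based on the class description above, not a verbatim schema.

```python
from omegaconf import OmegaConf

from trident.core.dataspec import TridentDataspec

# Hypothetical dataspec config; key names mirror the class description
# (dataset instantiation + dataloader setup) and may not match the exact schema.
cfg = OmegaConf.create(
    {
        "dataset": {
            "_target_": "datasets.load.load_dataset",
            "path": "glue",
            "name": "mnli",
            "split": "train",
        },
        "dataloader": {
            "_target_": "torch.utils.data.dataloader.DataLoader",
            "batch_size": 8,
            "num_workers": 0,
        },
    }
)

spec = TridentDataspec(cfg, name="train")
train_loader = spec.get_dataloader()
```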
- static preprocess(dataset, cfg)[source]¶
Applies preprocessing steps to the dataset as specified in the config.
The cfg includes two special keys:
- "method": Holds dictionaries of class methods and their keyword arguments for preprocessing.
- "apply": Contains dictionaries for user-defined functions and their keyword arguments to apply on the dataset.

The preprocessing functions take the Dataset as the first positional argument. The functions are called in the order of the configuration. Note that "method" is a convenience keyword: the same effect can be achieved by pointing "_target_" of an "apply" function to the class method (see the sketch below).
- Parameters:
  - dataset (Any) – The dataset to be preprocessed.
  - cfg (Optional[PreprocessingDict]) – A dictionary of preprocessing configurations.
- Returns:
The preprocessed dataset.
- Return type:
Any
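A minimal sketch of a preprocessing config using the "method" and "apply" keys described above. The exact schema of PreprocessingDict (e.g. whether "apply" entries use a Hydra `_target_`, and whether a plain dict is accepted in place of a DictConfig) is an assumption; the user-defined function and column names are illustrative.

```python
from datasets import load_dataset

from trident.core.dataspec import TridentDataspec


def truncate_premise(dataset, max_chars=128):
    # Hypothetical user-defined "apply" function: receives the Dataset as the
    # first positional argument; remaining config keys become keyword arguments.
    return dataset.map(lambda ex: {"premise": ex["premise"][:max_chars]})


dataset = load_dataset("glue", "mnli", split="train[:100]")

# Illustrative structure only; schema details may differ from PreprocessingDict.
cfg = {
    "method": {
        # -> dataset.rename_column("label", "labels"), called as a class method
        "rename_column": {"original_column_name": "label", "new_column_name": "labels"},
    },
    "apply": {
        "truncate": {"_target_": "__main__.truncate_premise", "max_chars": 64},
    },
}

dataset = TridentDataspec.preprocess(dataset, cfg)
```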
trident.core.module module¶
- class trident.core.module.TridentModule(model, optimizer, scheduler=None, initialize_model=True, *args, **kwargs)[source]¶
Bases:
OptimizerMixin, EvalMixin
Base module of Trident that wraps model, optimizer, scheduler, evaluation.
- Parameters:
  - model (DictConfig) – Needs to instantiate a torch.nn.Module that
      - Takes the batch unpacked
      - Returns a container with "loss" and other required attrs
    See also src.modules.base.TridentModule.forward(), src.modules.base.TridentModule.training_step(), tiny bert example
  - optimizer (DictConfig) – Configuration for the optimizer of your core.trident.TridentModule.
    See also src.modules.mixin.optimizer.OptimizerMixin, AdamW config
  - scheduler (Optional[DictConfig]) – Configuration for the scheduler of the optimizer of your src.modules.base.TridentModule.
    See also src.modules.mixin.optimizer.OptimizerMixin, Linear Warm-Up config
  - evaluation – Please refer to evaluation.
    See also src.modules.mixin.evaluation.EvalMixin, Classification Evaluation config

A configuration sketch follows below.
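A minimal sketch, not the project's canonical config layout: building a TridentModule from DictConfigs for model, optimizer, and scheduler. The model target and checkpoint name are illustrative (a tiny BERT is used only to echo the "tiny bert example" reference above), and the `_partial_` optimizer pattern is an assumption about how the OptimizerMixin consumes the config.

```python
from omegaconf import OmegaConf

from trident.core.module import TridentModule

model_cfg = OmegaConf.create(
    {
        # Any callable returning a torch.nn.Module works; a transformers model
        # is used here purely for illustration.
        "_target_": "transformers.AutoModelForSequenceClassification.from_pretrained",
        "pretrained_model_name_or_path": "prajjwal1/bert-tiny",
        "num_labels": 3,
    }
)
optimizer_cfg = OmegaConf.create(
    # `_partial_: true` defers binding the model parameters to the optimizer.
    {"_target_": "torch.optim.AdamW", "_partial_": True, "lr": 2e-5}
)
scheduler_cfg = None  # optional, e.g. a linear warm-up schedule config

module = TridentModule(model=model_cfg, optimizer=optimizer_cfg, scheduler=scheduler_cfg)
```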
- forward(batch)[source]¶
Plain forward pass of your model for which the batch is unpacked.
- Parameters:
  batch (dict) – input to your model
- Returns:
  container with attributes required for evaluation
- Return type:
  ModelOutput
- training_step(batch, batch_idx)[source]¶
Comprises the training step of your model, which performs a forward pass.
- Notes:
  If you want to extend training_step, add an on_train_batch_end method via overrides (see the sketch below). See: PyTorch Lightning's on_train_batch_end.
- Parameters:
  - batch (dict) – typically comprising input_ids, attention_mask, and position_ids
  - batch_idx (int) – variable used internally by pytorch-lightning
- Returns:
model output that must have ‘loss’ as attr or key
- Return type:
Union[dict[str, Any], ModelOutput]
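A minimal sketch of the override suggested in the note above. The hook signature follows PyTorch Lightning's on_train_batch_end; the logged metric name is illustrative.

```python
from typing import Any

from trident.core.module import TridentModule


class MyModule(TridentModule):
    def on_train_batch_end(self, outputs: Any, batch: dict, batch_idx: int) -> None:
        # `outputs` is whatever training_step returned, so it exposes "loss"
        # as an attribute or key.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs.loss
        self.log("train/loss_batch_end", loss, prog_bar=False)
```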
Module contents¶
- class trident.core.TridentDataModule(train=None, val=None, test=None)[source]¶
  Re-exported from trident.core.datamodule; documented above under trident.core.datamodule.TridentDataModule.
- class trident.core.TridentModule(model, optimizer, scheduler=None, initialize_model=True, *args, **kwargs)[source]¶
  Re-exported from trident.core.module; documented above under trident.core.module.TridentModule.