combustion.lightning¶
Utilities to facilitate operation with PyTorch Lightning.
class combustion.lightning.HydraMixin(*args, **kwargs)[source]¶
Mixin for creating a pytorch_lightning.LightningModule using Hydra.

The following pytorch_lightning.LightningModule abstract methods are implemented:
- train_dataloader
- val_dataloader
- test_dataloader
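Because the mixin supplies the dataloader and optimizer hooks, a subclass only needs its network definition and training step. The following is a minimal sketch, not taken from the library's docs; in particular, the assumption that the mixin reads the Hydra config from a config attribute on the module should be checked against the source.

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig

from combustion.lightning import HydraMixin


class MyModel(HydraMixin, pl.LightningModule):
    """Minimal sketch: dataloaders and optimizers come from the mixin."""

    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config  # assumed attribute consumed by the mixin
        self.layer = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        return F.cross_entropy(self(inputs), targets)

    # train_dataloader, val_dataloader, test_dataloader and
    # configure_optimizers are inherited from HydraMixin.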
configure_optimizers()[source]¶
Override for pytorch_lightning.LightningModule that automatically configures optimizers and learning rate scheduling based on a Hydra configuration. The Hydra config should have an optimizer section, and optionally a schedule section if learning rate scheduling is desired.

Sample Hydra Config
optimizer:
  name: adam
  cls: torch.optim.Adam
  params:
    lr: 0.001

schedule:
  interval: step
  monitor: val_loss
  frequency: 1
  cls: torch.optim.lr_scheduler.OneCycleLR
  params:
    max_lr: ${optimizer.params.lr}
    epochs: 10
    steps_per_epoch: 'none'
    pct_start: 0.03
    div_factor: 10
    final_div_factor: 10000.0
    anneal_strategy: cos
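For clarity, the snippet below shows the cls/params convention in isolation: the dotted path in cls is resolved to a class and the params mapping is expanded as keyword arguments. This is a hypothetical re-creation of the lookup for illustration, not the library's actual code path.

import pydoc

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    optimizer:
      cls: torch.optim.Adam
      params:
        lr: 0.001
    """
)

model = torch.nn.Linear(10, 2)
# Resolve the dotted path, then expand ``params`` as keyword arguments.
optim_cls = pydoc.locate(cfg.optimizer.cls)
optimizer = optim_cls(model.parameters(), **cfg.optimizer.params)
print(type(optimizer).__name__, optimizer.defaults["lr"])  # Adam 0.001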
get_lr(pos=0, param_group=0)[source]¶
Gets the current learning rate. Useful for logging the learning rate when using a learning rate schedule.

Parameters
- pos (int) – Position of the optimizer to query when multiple optimizers are configured. Defaults to 0.
- param_group (int) – Index of the optimizer parameter group to read the learning rate from. Defaults to 0.

Return type
float
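A common use is logging the scheduled learning rate during training. A minimal sketch, continuing the MyModel example above and assuming a Lightning version that provides self.log:

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        loss = F.cross_entropy(self(inputs), targets)
        # get_lr() reads the current rate from the active schedule
        self.log("lr", self.get_lr())
        return loss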
static instantiate(config, *args, **kwargs)[source]¶
Recursively instantiates classes in a Hydra configuration.

Parameters
config (omegaconf.DictConfig or dict) – The config to recursively instantiate from.

Return type
Any
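Recursion means nested instantiable nodes are built first and passed to the enclosing constructor. The sketch below reuses the target/params convention from the dataset config shown under prepare_data(); whether instantiate() accepts this exact key layout should be verified against the source.

from omegaconf import OmegaConf

from combustion.lightning import HydraMixin

cfg = OmegaConf.create(
    """
    target: torchvision.datasets.FakeData
    params:
      size: 100
      image_size: [1, 28, 28]
      transform:
        target: torchvision.transforms.ToTensor
    """
)

# The inner ``transform`` node is instantiated first (ToTensor()),
# then passed as a keyword argument to FakeData.
dataset = HydraMixin.instantiate(cfg)
print(type(dataset).__name__)  # FakeData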
prepare_data(force=False)[source]¶
Override for pytorch_lightning.LightningModule that automatically prepares any datasets based on a Hydra configuration. The Hydra config should have a dataset section, and optionally a schedule section if learning rate scheduling is desired.

The following keys can be provided to compute statistics on the training set:
- stats_sample_size – number of training examples to compute statistics over, or "all" to use the entire training set.
- stats_dim – dimension that will not be reduced when computing statistics. This is typically used when examples have more items than a simple (input, target) tuple, such as when working with masks. Defaults to 0.
- stats_index – tuple index of the data to compute statistics for. Defaults to 0.
The following statistics will be computed and attached as attributes if stats_sample_size is set:
- channel_mean
- channel_variance
- channel_min
- channel_max

Note
Training set statistics will be computed and attached the first time prepare_data() is called. Subsequent calls will not alter the attached statistics.

Parameters
force (bool) – By default, training datasets will only be loaded once. When force=True, datasets will always be reloaded.

Return type
None
Sample Hydra Config
dataset:
  stats_sample_size: 100   # compute training set statistics using 100 examples
  stats_dim: 0             # channel dimension to compute statistics for
  stats_index: 0           # tuple index to select from yielded example
  train:
    # passed to DataLoader
    num_workers: 1
    pin_memory: true
    drop_last: true
    shuffle: true
    # instantiates dataset
    target: torchvision.datasets.FakeData
    params:
      size: 10000
      image_size: [1, 128, 128]
      transform:
        target: torchvision.transforms.ToTensor
  # test/validation sets can be explicitly given as above,
  # or as a split from the training set
  # as a random split from training set by number of examples
  # validate: 32
  # as a random split from training set by fraction
  # test: 0.1
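Once prepare_data() has attached the statistics, they can be used for input normalization inside the module. A minimal sketch; the attribute names come from the list above, while their shapes and broadcasting behavior are assumptions:

    def forward(self, x):
        # channel_mean / channel_variance are attached by prepare_data()
        # when stats_sample_size is set (shapes assumed to broadcast over x)
        x = (x - self.channel_mean) / self.channel_variance.clamp_min(1e-6).sqrt()
        return self.layer(x)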
Callbacks¶
class combustion.lightning.callbacks.TorchScriptCallback(path=None, trace=False, sample_input=None)[source]¶
Callback to export a model using TorchScript upon completion of training.

Note
The _device: ... type hint on pytorch_lightning.LightningModule causes problems with TorchScript exports. This type hint must be manually overridden as follows:

>>> class MyModule(pl.LightningModule):
>>>     _device: torch.device
>>>     ...
Parameters
- path (str, optional) – The filepath where the exported model will be saved. If unset, the model will be saved in the PyTorch Lightning default save path.
- trace (bool, optional) – If true, export a torch.jit.ScriptModule using torch.jit.trace(). Otherwise, torch.jit.script() will be used.
- sample_input (Any, optional) – Sample input data to use with torch.jit.trace(). If sample_input is unset and trace is true, the attribute example_input_array will be used as input. If trace is true and example_input_array is unset, a RuntimeError will be raised.
on_train_end(trainer, pl_module)[source]¶
Called after training to export a model using TorchScript.

Parameters
- trainer (pytorch_lightning.trainer.trainer.Trainer) – The pytorch_lightning.Trainer instance.
- pl_module (pytorch_lightning.core.lightning.LightningModule) – The pytorch_lightning.LightningModule to export.

Return type
None
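A usage sketch: register the callback with the trainer and the export runs automatically at the end of fit(). The output path here is hypothetical:

import pytorch_lightning as pl

from combustion.lightning.callbacks import TorchScriptCallback

# trace=False -> torch.jit.script() is used, so no sample input is needed
callback = TorchScriptCallback(path="./model.pt", trace=False)
trainer = pl.Trainer(callbacks=[callback], max_epochs=10)
trainer.fit(model)  # model: your LightningModule; exported on train end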
class combustion.lightning.callbacks.CountMACs(sample_input=None, custom_ops=None)[source]¶
Callback to output the approximate number of MAC (multiply-accumulate) operations and parameters in a model. Runs at the start of training.

Note
Counting MACs requires thop.

Total MACs / parameters are logged and attached to the model as attributes:
- total_macs
- total_params

Parameters
- sample_input (Tuple, optional) – Sample input data to use when counting MACs. If sample_input is not given, the callback will attempt to use the attribute module.example_input_array as a sample input. If no sample input can be found, a warning will be raised.
- custom_ops (Dict[type, Callable], optional) – Forwarded to thop.profile().
on_train_start(trainer, pl_module)[source]¶
Called at the start of training.

Parameters
- trainer (pytorch_lightning.trainer.trainer.Trainer) – The pytorch_lightning.Trainer instance.
- pl_module (pytorch_lightning.core.lightning.LightningModule) – The pytorch_lightning.LightningModule to analyze.
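A usage sketch; the sample input shape is hypothetical and should match your model's expected input:

import torch
import pytorch_lightning as pl

from combustion.lightning.callbacks import CountMACs

# thop must be installed for MAC counting to run
macs_cb = CountMACs(sample_input=(torch.rand(1, 1, 128, 128),))
trainer = pl.Trainer(callbacks=[macs_cb])
trainer.fit(model)
# per the docs above, counts are attached to the model afterwards:
# model.total_macs, model.total_params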