combustion.lightning

Utilities for working with PyTorch Lightning.

class combustion.lightning.HydraMixin(*args, **kwargs)[source]

Mixin for creating a pytorch_lightning.LightningModule using Hydra.

The following pytorch_lightning.LightningModule abstract methods are implemented:

configure_optimizers()[source]

Override for pytorch_lightning.LightningModule that automatically configures optimizers and learning rate scheduling based on a Hydra configuration.

The Hydra config should have an optimizer section, and optionally a schedule section if learning rate scheduling is desired.

Sample Hydra Config

optimizer:
  name: adam
  cls: torch.optim.Adam
  params:
    lr: 0.001

schedule:
  interval: step
  monitor: val_loss
  frequency: 1
  cls: torch.optim.lr_scheduler.OneCycleLR
  params:
    max_lr: ${optimizer.params.lr}
    epochs: 10
    steps_per_epoch: 'none'
    pct_start: 0.03
    div_factor: 10
    final_div_factor: 10000.0
    anneal_strategy: cos
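
As a sketch of how the pieces fit together, a LightningModule that mixes in HydraMixin inherits configure_optimizers() and only needs a config like the one above. The module name, layers, and the assumption that the config is stored on self.config are illustrative, not part of combustion:

import pytorch_lightning as pl
import torch
from combustion.lightning import HydraMixin

class MyModel(HydraMixin, pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        # Assumed attribute name; the mixin reads the optimizer /
        # schedule sections from the attached Hydra config.
        self.config = config
        self.layer = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.layer(x)

    # configure_optimizers() is inherited from HydraMixin; given the
    # config above it builds torch.optim.Adam and a OneCycleLR schedule.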
get_lr(pos=0, param_group=0)[source]

Gets the current learning rate. Useful for logging learning rate when using a learning rate schedule.

Parameters
  • pos (int, optional) – The index of the optimizer to retrieve a learning rate from. When using a single optimizer this can be omitted.

  • param_group (int, optional) – The index of the parameter group to retrieve a learning rate for. When using one optimizer for the entire model this can be omitted.

Return type

float
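
For example, a hedged sketch of logging the learning rate from a training step (the loss computation is a placeholder, and self.log follows the standard Lightning API):

import torch.nn.functional as F

# Inside a HydraMixin subclass such as the MyModel sketch above:
def training_step(self, batch, batch_idx):
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    # With a single optimizer and one parameter group the default
    # pos=0, param_group=0 arguments suffice.
    self.log("lr", self.get_lr())
    self.log("train_loss", loss)
    return loss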

static instantiate(config, *args, **kwargs)[source]

Recursively instantiates classes in a Hydra configuration.

Parameters

config (omegaconf.DictConfig or dict) – The config to recursively instantiate from.

Return type

Any
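
A minimal sketch of direct use, assuming the cls/params key names shown in the optimizer example above (the dataset example below uses target instead; the Linear model here is purely illustrative):

import torch
from omegaconf import OmegaConf
from combustion.lightning import HydraMixin

config = OmegaConf.create({
    "cls": "torch.optim.Adam",  # dotted path of the class to instantiate
    "params": {"lr": 0.001},    # keyword arguments for the constructor
})

# Positional args are forwarded to the constructor of the target class.
model = torch.nn.Linear(10, 2)
optimizer = HydraMixin.instantiate(config, model.parameters())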

prepare_data(force=False)[source]

Override for pytorch_lightning.LightningModule that automatically prepares any datasets based on a Hydra configuration. The Hydra config should have a dataset section; validation and test sets may be given explicitly or as splits of the training set, as in the sample config below.

The following keys can be provided to compute statistics on the training set:
  • stats_sample_size - number of training examples to compute statistics over, or "all" to use the entire training set.

  • stats_dim - dimension that will not be reduced when computing statistics. Defaults to 0.

  • stats_index - tuple index of the data to compute statistics for. This is typically used when examples have more items than a simple (input, target) tuple, such as when working with masks. Defaults to 0.

The following statistics will be computed and attached as attributes if stats_sample_size is set:
  • channel_mean

  • channel_variance

  • channel_min

  • channel_max

Note

Training set statistics will be computed and attached the first time prepare_data() is called. Subsequent calls will not alter the attached statistics.

Parameters

force (bool) – By default, training datasets will only be loaded once. When force=True, datasets will always be reloaded.

Return type

None

Sample Hydra Config

dataset:
  stats_sample_size: 100    # compute training set statistics using 100 examples
  stats_dim: 0              # channel dimension to compute statistics for
  stats_index: 0            # tuple index to select from yielded example
  train:
    # passed to DataLoader
    num_workers: 1
    pin_memory: true
    drop_last: true
    shuffle: true

    # instantiates dataset
    target: torchvision.datasets.FakeData
    params:
      size: 10000
      image_size: [1, 128, 128]
      transform:
        target: torchvision.transforms.ToTensor

  # test/validation sets can be explicitly given as above,
  # or as a split from training set

  # as a random split from training set by number of examples
  # validate: 32

  # as a random split from training set by fraction
  # test: 0.1
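
A hedged usage sketch: with stats_sample_size set as above, the computed statistics become attributes after preparation (MyModel stands in for any HydraMixin subclass, as sketched earlier):

model = MyModel(config)
model.prepare_data()

# Attached by prepare_data() because stats_sample_size is set:
print(model.channel_mean, model.channel_variance)
print(model.channel_min, model.channel_max)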

Callbacks

class combustion.lightning.callbacks.TorchScriptCallback(path=None, trace=False, sample_input=None)[source]

Callback to export a model using TorchScript upon completion of training.

Note

The _device: ... type hint defined on pytorch_lightning.LightningModule causes problems with TorchScript exports. This type hint must be manually overridden as follows:

>>> class MyModule(pl.LightningModule):
...     _device: torch.device
...     ...

Parameters
  • path (str, optional) – The filepath where the exported model will be saved. If unset, the model will be saved in the PyTorch Lightning default save path.

  • trace (bool, optional) – If true, export a torch.jit.ScriptModule using torch.jit.trace(). Otherwise, torch.jit.script() will be used.

  • sample_input (Any, optional) – Sample input data to use with torch.jit.trace(). If sample_input is unset and trace is true, the attribute example_input_array will be used as input. If trace is true and example_input_array is unset a RuntimeError will be raised.

on_train_end(trainer, pl_module)[source]

Called after training to export a model using TorchScript.

Parameters
  • trainer (pytorch_lightning.Trainer) – The trainer that ran training.

  • pl_module (pytorch_lightning.LightningModule) – The module to export.

Return type

None
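
A minimal usage sketch (the save path is illustrative; Trainer and its callbacks argument follow the standard Lightning API):

import pytorch_lightning as pl
from combustion.lightning.callbacks import TorchScriptCallback

# model: any LightningModule, e.g. the MyModel sketch above.
# Export via torch.jit.script() to an explicit path when training ends.
callback = TorchScriptCallback(path="model.pt", trace=False)
trainer = pl.Trainer(callbacks=[callback])
trainer.fit(model)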

class combustion.lightning.callbacks.CountMACs(sample_input=None, custom_ops=None)[source]

Callback that reports the approximate number of MAC (multiply-accumulate) operations and parameters in a model. Runs at the start of training.

Note

Counting MACs requires the thop package.

Total MACs / parameters are logged and attached to the model as attributes:

  • total_macs

  • total_params

Parameters
  • sample_input (Tuple, optional) – Sample input data to use when counting MACs. If sample_input is not given, the callback will attempt to use the attribute module.example_input_array as a sample input. If no sample input can be found, a warning will be raised.

  • custom_ops (Dict[type, Callable], optional) – Forwarded to thop.profile()
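
A hedged sketch of typical use (the input shape matches the FakeData example above and is otherwise illustrative):

import torch
import pytorch_lightning as pl
from combustion.lightning.callbacks import CountMACs

# model: any LightningModule. Pass a sample input explicitly instead
# of relying on module.example_input_array.
callback = CountMACs(sample_input=(torch.rand(1, 1, 128, 128),))
trainer = pl.Trainer(callbacks=[callback])
trainer.fit(model)

# Attached by the callback when training starts:
# model.total_macs, model.total_params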

on_train_start(trainer, pl_module)[source]

Called at the start of training to count MACs and parameters.

Parameters
  • trainer (pytorch_lightning.Trainer) – The trainer that will run training.

  • pl_module (pytorch_lightning.LightningModule) – The module to profile.

Return type

None