Combustion documentation
Combustion is a collection of useful PyTorch utilities.
Training Helpers
combustion.initialize(config_path, config_name, caller_stack_depth=1)[source]

    Performs the initialization needed to configure multiruns / parameter sweeps via a YAML config file. This is only required when multirun configuration via a YAML file is desired; otherwise, combustion.main() can be used without calling initialize(). See combustion.main() for usage examples.

    Warning
        This method makes use of Hydra's compose API, which is experimental as of version 1.0.

    Warning
        This method works by inspecting the "sweeper" section of the specified config file and altering sys.argv to include the chosen sweeper parameters.

    Parameters
        config_path (str) – Path to the main configuration file. See hydra.main() for more details.
        config_name (str) – Name of the main configuration file. See hydra.main() for more details.
        caller_stack_depth (int) – Stack depth when calling initialize(). Defaults to 1 (direct caller).
    Return type
        None
    Sample sweeper Hydra config:

        sweeper:
          model.params.batch_size: 8,16,32
          optimizer.params.lr: 0.001,0.002,0.0003
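    For orientation, below is a minimal sketch of what a full config file containing such a sweeper section might look like. Everything outside the sweeper block is a hypothetical layout, included only to show where the dotted sweeper keys would resolve:

        # conf/config.yaml -- hypothetical layout; the model/optimizer
        # sections are assumptions, shown only to illustrate the targets
        # of the dotted keys in the sweeper section.
        model:
          params:
            batch_size: 32
        optimizer:
          params:
            lr: 0.001
        sweeper:
          model.params.batch_size: 8,16,32
          optimizer.params.lr: 0.001,0.002,0.0003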
combustion.main(cfg, process_results_fn=None)[source]

    Main method for training/testing a model using PyTorch Lightning and Hydra.

    This method is robust to exceptions (other than SystemExit or KeyboardInterrupt), which makes it useful with Hydra's multirun feature: if one combination of hyperparameters results in an exception, the other combinations will still be attempted. This behavior can be overridden by providing a check_exceptions bool value under config.trainer, which is useful when writing tests.

    Automatic learning rate selection is performed using auto_lr_find().

    Training / testing is performed automatically based on the configuration keys present in config.dataset.

    Parameters
        cfg (DictConfig) – The Hydra config.
        process_results_fn (callable, optional) – If given, process_results_fn is called on the (train_results, test_results) tuple returned by this method. This is useful for reducing training/testing results to a scalar return value when using an optimization sweeper (like Ax); see the sketch after the examples below.

    Returns
        The (train_results, test_results) tuple, or the result of process_results_fn if one was given.

    Example:
    >>> # define main method as per Hydra that calls combustion.main()
    >>> @hydra.main(config_path="./conf", config_name="config")
    >>> def main(cfg):
    >>>     combustion.main(cfg)
    >>>
    >>> if __name__ == "__main__":
    >>>     main()
Example (multirun from config file):
    >>> combustion.initialize(config_path="./conf", config_name="config")
    >>>
    >>> @hydra.main(config_path="./conf", config_name="config")
    >>> def main(cfg):
    >>>     return combustion.main(cfg)
    >>>
    >>> if __name__ == "__main__":
    >>>     main()
    >>>     combustion.check_exceptions()
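    Since process_results_fn is only described above, here is a minimal sketch of such a callable, assuming it receives the (train_results, test_results) tuple as a single argument and that the test results carry a "test_loss" key (the key name and the averaging are assumptions for illustration):

        # Hypothetical reducer for use with an optimization sweeper like Ax.
        def process_results(results):
            train_results, test_results = results
            # "test_loss" is an assumed key; adapt to whatever your test
            # loop actually reports.
            losses = [r["test_loss"] for r in test_results]
            # Return a scalar for the sweeper to optimize.
            return sum(losses) / len(losses)

        # Passed through from your Hydra entry point:
        #     return combustion.main(cfg, process_results_fn=process_results)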
combustion.check_exceptions()[source]

    Checks whether any exceptions were raised over the course of a multirun. Most exceptions are ignored by combustion.main() to prevent a failed run from killing an entire hyperparameter search; however, one may still want to raise an exception at the conclusion of a multirun (e.g., for testing purposes). This method checks whether any exceptions were raised and, if so, raises a combustion.MultiRunError.

    Example:
    >>> @hydra.main(config_path="./conf", config_name="config")
    >>> def main(cfg):
    >>>     combustion.main(cfg)
    >>>
    >>> if __name__ == "__main__":
    >>>     main()
    >>>     combustion.check_exceptions()
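    As a sketch of the testing use case mentioned above (hypothetical test code assuming pytest; only combustion.MultiRunError and check_exceptions() come from the description above):

        import pytest
        import combustion

        def test_multirun_surfaces_failures():
            # run_failing_multirun() is a hypothetical helper that drives a
            # multirun in which at least one hyperparameter combination raises.
            run_failing_multirun()
            # check_exceptions() should then re-raise as a MultiRunError.
            with pytest.raises(combustion.MultiRunError):
                combustion.check_exceptions()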
combustion.auto_lr_find(cfg, model)[source]

    Performs automatic learning rate selection using PyTorch Lightning. This is essentially a wrapper that invokes PyTorch Lightning's automatic LR selection using Hydra inputs. The model's learning rate is automatically set to the selected value, which is also logged. If possible, a plot of the learning rate selection curve will be produced as well.

    Parameters
        cfg (DictConfig) – The Hydra config.
        model (LightningModule) – The model to select a learning rate for.

    Returns
        The learning rate if one was found, otherwise None.

    Return type
        Optional[float]
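    A minimal usage sketch, following the Hydra entry-point pattern used in the examples above; MyModel is a hypothetical LightningModule:

        import hydra
        import pytorch_lightning as pl
        import combustion

        class MyModel(pl.LightningModule):
            # Hypothetical model; assumed to expose a tunable learning rate.
            ...

        @hydra.main(config_path="./conf", config_name="config")
        def main(cfg):
            model = MyModel()
            lr = combustion.auto_lr_find(cfg, model)
            # Per the description above, the model's learning rate has
            # already been updated in place when a rate was found.
            if lr is not None:
                print(f"selected learning rate: {lr}")

        if __name__ == "__main__":
            main()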