combustion.testing¶
Assertions¶
combustion.testing.assert_has_gradient(module, recurse=True)[source]¶
Asserts that the parameters in a module have requires_grad=True and that the gradient exists.
Parameters
- module (torch.nn.Module) – The module to inspect
- recurse (bool, optional) – Whether or not to recursively run the same assertion on the gradients of child modules.
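As a rough sketch of the condition this assertion enforces (not the library's actual implementation), the check can be expressed in plain torch:

```python
import torch

def has_gradient(module: torch.nn.Module) -> bool:
    # Rough equivalent of assert_has_gradient: every parameter must
    # require grad and have a populated .grad attribute after backward().
    return all(p.requires_grad and p.grad is not None for p in module.parameters())

layer = torch.nn.Linear(4, 2)
layer(torch.rand(1, 4)).sum().backward()
print(has_gradient(layer))  # True: backward() populated the gradients
```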
combustion.testing.assert_in_eval_mode(module)[source]¶
Asserts that the module is in inference mode, i.e. module.eval() was called.
Parameters
- module (torch.nn.Module) – The module to inspect
combustion.testing.assert_in_training_mode(module)[source]¶
Asserts that the module is in training mode, i.e. module.train() was called.
Parameters
- module (torch.nn.Module) – The module to inspect
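Both mode assertions reduce to inspecting the module's training flag, which module.train() and module.eval() toggle. A minimal illustration:

```python
import torch

model = torch.nn.Dropout(p=0.5)
model.eval()           # the state assert_in_eval_mode expects
print(model.training)  # False
model.train()          # the state assert_in_training_mode expects
print(model.training)  # True
```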
combustion.testing.assert_is_int_tensor(x)[source]¶
Asserts that the values of a floating point tensor are integers. This test is equivalent to torch.allclose(x, x.round()).
Parameters
- x (torch.Tensor) – The tensor to inspect
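The stated equivalence with torch.allclose(x, x.round()) behaves as follows:

```python
import torch

whole = torch.tensor([1.0, 2.0, -3.0])
print(bool(torch.allclose(whole, whole.round())))  # True: all values are integral

frac = torch.tensor([1.5, 2.0])
print(bool(torch.allclose(frac, frac.round())))    # False: 1.5 is not integral
```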
combustion.testing.assert_tensors_close(x, y, *args, **kwargs)[source]¶
Asserts that the values of two tensors are close. This is similar to torch.allclose(), but has cleaner output when used with pytest.
Parameters
- x (torch.Tensor) – The first tensor.
- y (torch.Tensor) – The second tensor.
Additional positional or keyword args are passed to torch.allclose().
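Because extra arguments are forwarded to torch.allclose(), its tolerance parameters can be passed through. For example:

```python
import torch

x = torch.ones(3)
# torch.allclose defaults: rtol=1e-05, atol=1e-08
print(bool(torch.allclose(x, x + 1e-9)))           # True: within default tolerance
print(bool(torch.allclose(x, x + 0.1)))            # False: clearly not close
print(bool(torch.allclose(x, x + 0.1, atol=0.5)))  # True: loosened absolute tolerance
```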
combustion.testing.assert_zero_grad(module, recurse=True)[source]¶
Asserts that the parameters in a module have zero gradients. Useful for checking if Optimizer.zero_grad() was called.
Parameters
- module (torch.nn.Module) – The module to inspect
- recurse (bool, optional) – Whether or not to recursively run the same assertion on the gradients of child modules.
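A rough sketch of the condition being asserted (not the library's actual implementation; note that Optimizer.zero_grad() may either zero gradients or set them to None depending on the PyTorch version):

```python
import torch

layer = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
layer(torch.rand(1, 4)).sum().backward()
opt.zero_grad()

# Rough equivalent of the assertion: gradients are absent or all zero.
zeroed = all(p.grad is None or bool((p.grad == 0).all()) for p in layer.parameters())
print(zeroed)  # True: zero_grad() cleared every gradient
```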
PyTest Decorators¶
combustion.testing.cuda_or_skip(*args, **kwargs)¶
Run test only if torch.cuda.is_available() is true.
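A decorator with this behavior can be sketched with pytest's own skipif marker (an assumed stand-in, not the library's actual implementation):

```python
import pytest
import torch

# Hypothetical equivalent of cuda_or_skip built on pytest.mark.skipif
cuda_or_skip = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is not available"
)

@cuda_or_skip
def test_needs_gpu():
    assert torch.cuda.is_available()
```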
LightningModuleTest¶
In order to facilitate model testing without excessive boilerplate code, a pytest base class is provided that supplies a set of minimal tests for a PyTorch Lightning LightningModule. By implementing a small number of fixtures that provide the model under test and data to test it with, many stages of the LightningModule lifecycle can be tested without writing any additional test code.
class combustion.testing.LightningModuleTest[source]¶
Base class to automate testing of LightningModules with Pytest.
The following fixtures should be implemented in the subclass:
- model()
- data()
Simple tests are provided for the following lifecycle hooks:
- configure_optimizers()
- prepare_data() (optional)
- train_dataloader()
- val_dataloader() (optional)
- test_dataloader() (optional)
- training_step() – single process and with torch.nn.parallel.DistributedDataParallel
- validation_step() (optional)
- test_step() (optional)
- validation_epoch_end() (optional)
- test_epoch_end() (optional)
If the model under test does not implement an optional method, the test will be skipped.
The following mock attributes will be attached to your model as PropertyMock:
- logger
- trainer
Example Usage:
>>> # minimal example
>>> class TestModel(LightningModuleTest):
>>>     @pytest.fixture
>>>     def model(self):
>>>         return ...  # return your model here
>>>
>>>     @pytest.fixture
>>>     def data(self):
>>>         return torch.rand(2, 1, 10, 10)  # will be passed to model.forward()
test_configure_optimizers(model)[source]¶
Tests that model.configure_optimizers() runs and returns the required outputs.
test_forward(model, data, training)[source]¶
Calls model.forward() and tests that the output is not None.
Because of the size of some models, this test is only run when a GPU is available.
Parameters
- data (torch.Tensor) –
- training (bool) –
test_prepare_data(model)[source]¶
Calls model.prepare_data() to see if any fatal errors are thrown. No tests are performed to assess change of state.
test_test_dataloader(model)[source]¶
Tests that model.test_dataloader() runs and returns the required output.
test_test_epoch_end(model)[source]¶
Tests that test_epoch_end() runs and outputs a dict as required by PyTorch Lightning.
Because of the size of some models, this test is only run when a GPU is available.
test_test_step(model)[source]¶
Runs a testing step based on the data returned from model.test_dataloader(). Tests that the dictionary returned from test_step() is as required by PyTorch Lightning.
Because of the size of some models, this test is only run when a GPU is available.
test_train_dataloader(model)[source]¶
Tests that model.train_dataloader() runs and returns the required output.
test_training_step(model, distributed)[source]¶
Runs a training step based on the data returned from model.train_dataloader(). Tests that the dictionary returned from training_step() is as required by PyTorch Lightning. A backward pass and optimizer step are also performed using the optimizer provided by LightningModule.configure_optimizers(). By default, training steps are tested for distributed and non-distributed models using the torch.nn.parallel.DistributedDataParallel wrapper. Distributed tests can be disabled by setting LightningModuleTest.DISTRIBUTED to False.
Because of the size of some models, this test is only run when a GPU is available.
Parameters
- distributed (bool) –
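The single-process portion of this test amounts to roughly the following sequence (a sketch, not the actual implementation; a plain Linear model stands in for the LightningModule):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.rand(8, 4)

optimizer.zero_grad()
loss = model(batch).sum()  # stand-in for the "loss" entry of training_step's output
loss.backward()            # backward pass
optimizer.step()           # optimizer step

print(all(p.grad is not None for p in model.parameters()))  # True: grads were populated
```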
test_val_dataloader(model)[source]¶
Tests that model.val_dataloader() runs and returns the required output.
test_validation_epoch_end(model)[source]¶
Tests that validation_epoch_end() runs and outputs a dict as required by PyTorch Lightning.
Because of the size of some models, this test is only run when a GPU is available.
test_validation_step(model)[source]¶
Runs a validation step based on the data returned from model.val_dataloader(). Tests that the dictionary returned from validation_step() is as required by PyTorch Lightning.
Because of the size of some models, this test is only run when a GPU is available.
Mixins¶
class combustion.testing.TorchScriptTestMixin[source]¶
Mixin to test a torch.nn.Module’s ability to be scripted using torch.jit.script(), saved to disk, and loaded.
The following fixtures should be implemented in the subclass:
- model() – returns the model to be tested
test_load_scripted(model, tmp_path)[source]¶
Tests that a torch.jit.ScriptModule saved to disk using torch.jit.script() can be loaded, and that the loaded object is a torch.jit.ScriptModule.
test_save_scripted(model, tmp_path)[source]¶
Calls torch.jit.script() on the given model and tests that the resultant torch.jit.ScriptModule can be saved to disk using torch.jit.save().
test_script(model)[source]¶
Calls torch.jit.script() on the given model and tests that a torch.jit.ScriptModule is returned.
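Together, these three tests exercise a script/save/load round trip along the following lines (a sketch using a toy model):

```python
import os
import tempfile
import torch

model = torch.nn.Linear(4, 2)

scripted = torch.jit.script(model)                   # what test_script checks
print(isinstance(scripted, torch.jit.ScriptModule))  # True

path = os.path.join(tempfile.mkdtemp(), "scripted.pt")
torch.jit.save(scripted, path)                       # what test_save_scripted checks
loaded = torch.jit.load(path)                        # what test_load_scripted checks
print(isinstance(loaded, torch.jit.ScriptModule))    # True
```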
class combustion.testing.TorchScriptTraceTestMixin[source]¶
Mixin to test a torch.nn.Module’s ability to be traced using torch.jit.trace(), saved to disk, and loaded.
The following fixtures should be implemented in the subclass:
- model() – returns the model to be tested
- data() – returns an input to model.forward()
test_load_traced(model, tmp_path, data)[source]¶
Tests that a torch.jit.ScriptModule saved to disk using torch.jit.trace() can be loaded, and that the loaded object is a torch.jit.ScriptModule.
test_save_traced(model, tmp_path, data)[source]¶
Calls torch.jit.trace() on the given model and tests that the resultant torch.jit.ScriptModule can be saved to disk using torch.jit.save().
test_trace(model, data)[source]¶
Calls torch.jit.trace() on the given model and tests that a torch.jit.ScriptModule is returned.
test_traced_forward_call(model, data)[source]¶
Calls torch.jit.trace() on the given model and tests that the traced module's forward() can be called with the provided data.
Because of the size of some models, this test is only run when a GPU is available.
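The traced counterpart of the round trip, including the final forward call, can be sketched as follows (a toy model stands in for the fixtures):

```python
import os
import tempfile
import torch

model = torch.nn.Linear(4, 2)
data = torch.rand(1, 4)

traced = torch.jit.trace(model, data)              # what test_trace checks
print(isinstance(traced, torch.jit.ScriptModule))  # True

path = os.path.join(tempfile.mkdtemp(), "traced.pt")
torch.jit.save(traced, path)                       # what test_save_traced checks
loaded = torch.jit.load(path)                      # what test_load_traced checks

out = loaded(data)                                 # what test_traced_forward_call checks
print(tuple(out.shape))                            # (1, 2)
```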