combustion.testing

Utilities for testing PyTorch code: assertions for modules and tensors, pytest decorators, and base classes that automate testing of models.

Assertions

combustion.testing.assert_has_gradient(module, recurse=True)[source]

Asserts that the parameters in a module have requires_grad=True and that the gradient exists.

Parameters
  • module (torch.nn.Module) – The module to inspect

  • recurse (bool, optional) – Whether or not to recursively run the same assertion on the gradients of child modules.
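
Example usage (a minimal sketch; the toy nn.Linear module is illustrative)::
>>> import torch
>>> import torch.nn as nn
>>> from combustion.testing import assert_has_gradient
>>> model = nn.Linear(4, 2)
>>> model(torch.rand(8, 4)).sum().backward()  # populates .grad on each parameter
>>> assert_has_gradient(model)  # passes: requires_grad=True and gradients exist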

combustion.testing.assert_in_eval_mode(module)[source]

Asserts that the module is in inference mode, i.e. module.eval() was called.

Parameters

module (torch.nn.Module) – The module to inspect

combustion.testing.assert_in_training_mode(module)[source]

Asserts that the module is in training mode, i.e. module.train() was called.

Parameters

module (torch.nn.Module) – The module to inspect
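
Example usage covering both mode assertions (a minimal sketch)::
>>> import torch.nn as nn
>>> from combustion.testing import assert_in_eval_mode, assert_in_training_mode
>>> module = nn.Dropout()
>>> module.train()
>>> assert_in_training_mode(module)  # passes after .train()
>>> module.eval()
>>> assert_in_eval_mode(module)  # passes after .eval()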

combustion.testing.assert_is_int_tensor(x)[source]

Asserts that the values of a floating point tensor are integers. This test is equivalent to torch.allclose(x, x.round()).

Parameters

x (torch.Tensor) – The tensor to inspect
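
Example usage (a minimal sketch)::
>>> import torch
>>> from combustion.testing import assert_is_int_tensor
>>> assert_is_int_tensor(torch.tensor([1.0, 2.0]))  # passes: all values are integral
>>> assert_is_int_tensor(torch.tensor([1.0, 2.5]))  # raises AssertionError
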
combustion.testing.assert_tensors_close(x, y, *args, **kwargs)[source]

Asserts that the values of two tensors are close. This is similar to torch.allclose(), but has cleaner output when used with pytest.

Parameters
  • x (torch.Tensor) – The first tensor

  • y (torch.Tensor) – The second tensor

Additional positional or keyword args are passed to torch.allclose().
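
Example usage (a minimal sketch; the tolerance values are illustrative)::
>>> import torch
>>> from combustion.testing import assert_tensors_close
>>> x = torch.rand(3, 3)
>>> assert_tensors_close(x, x.clone())  # passes: identical values
>>> assert_tensors_close(x, x + 1e-4, atol=1e-3)  # kwargs are forwarded to torch.allclose()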

combustion.testing.assert_zero_grad(module, recurse=True)[source]

Asserts that the parameters in a module have zero gradients. Useful for checking if Optimizer.zero_grad() was called.

Parameters
  • module (torch.nn.Module) – The module to inspect

  • recurse (bool, optional) – Whether or not to recursively run the same assertion on the gradients of child modules.
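
Example usage (a minimal sketch; note that newer PyTorch versions can clear gradients to None via set_to_none, so it is passed explicitly here to keep zeroed .grad tensors)::
>>> import torch
>>> import torch.nn as nn
>>> from combustion.testing import assert_zero_grad
>>> model = nn.Linear(4, 2)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>> model(torch.rand(8, 4)).sum().backward()  # accumulate some gradients
>>> optimizer.zero_grad(set_to_none=False)  # zero the gradients in place
>>> assert_zero_grad(model)  # passes once all gradients are zeroed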

PyTest Decorators

combustion.testing.cuda_or_skip(*args, **kwargs)

Run test only if torch.cuda.is_available() is true.

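Assuming the decorator is applied directly to a test function, usage looks like the following sketch::
>>> import torch
>>> from combustion.testing import cuda_or_skip
>>> @cuda_or_skip
>>> def test_cuda_forward():
>>>     model = torch.nn.Linear(4, 2).cuda()
>>>     out = model(torch.rand(8, 4, device="cuda"))
>>>     assert out.shape == (8, 2)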

LightningModuleTest

To facilitate model testing without excessive boilerplate, a pytest base class is provided that supplies a minimal set of tests for a PyTorch Lightning LightningModule. By implementing a small number of fixtures that provide a model to be tested and data to test it with, many stages of the LightningModule lifecycle can be exercised without writing any additional test code.

class combustion.testing.LightningModuleTest[source]

Base class to automate testing of LightningModules with Pytest.

The following fixtures should be implemented in the subclass:

  • model()

  • data()

Simple tests are provided for the following lifecycle hooks:
  • configure_optimizers()

  • prepare_data() (optional)

  • train_dataloader()

  • val_dataloader() (optional)

  • test_dataloader() (optional)

  • training_step() - single process and with torch.nn.parallel.DistributedDataParallel

  • validation_step() (optional)

  • test_step() (optional)

  • validation_epoch_end() (optional)

  • test_epoch_end() (optional)

If the model under test does not implement an optional method, the test will be skipped.

The following mock attributes will be attached to your model as PropertyMock:
  • logger

  • trainer

Example Usage::
>>> # minimal example
>>> class TestModel(LightningModuleTest):
>>>     @pytest.fixture
>>>     def model(self):
>>>         return ...  # return your model here
>>>
>>>     @pytest.fixture
>>>     def data(self):
>>>         return torch.rand(2, 1, 10, 10)  # will be passed to model.forward()

test_configure_optimizers(model)[source]

Tests that model.configure_optimizers() runs and returns the required outputs.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_forward(model, data, training)[source]

Calls model.forward() and tests that the output is not None.

Because of the size of some models, this test is only run when a GPU is available.

Parameters
  • model (pytorch_lightning.core.lightning.LightningModule) –

  • data (torch.Tensor) –

  • training (bool) –

test_prepare_data(model)[source]

Calls model.prepare_data() to see if any fatal errors are thrown. No tests are performed to assess change of state.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_test_dataloader(model)[source]

Tests that model.test_dataloader() runs and returns the required output.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_test_epoch_end(model)[source]

Tests that test_epoch_end() runs and outputs a dict as required by PyTorch Lightning.

Because of the size of some models, this test is only run when a GPU is available.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_test_step(model)[source]

Runs a testing step based on the data returned from model.test_dataloader(). Tests that the dictionary returned from test_step() is as required by PyTorch Lightning.

Because of the size of some models, this test is only run when a GPU is available.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_train_dataloader(model)[source]

Tests that model.train_dataloader() runs and returns the required output.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_training_step(model, distributed)[source]

Runs a training step based on the data returned from model.train_dataloader(). Tests that the dictionary returned from training_step() is as required by PyTorch Lightning. A backward pass and optimizer step are also performed using the optimizer provided by LightningModule.configure_optimizers(). By default, training steps are tested for both distributed and non-distributed models using the torch.nn.parallel.DistributedDataParallel wrapper. Distributed tests can be disabled by setting LightningModuleTest.DISTRIBUTED to False, as sketched below.

Because of the size of some models, this test is only run when a GPU is available.

Parameters
  • model (pytorch_lightning.core.lightning.LightningModule) –

  • distributed (bool) –
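
A sketch of disabling the distributed variant (the TestMyModel class name is hypothetical; this mirrors the class example above)::
>>> class TestMyModel(LightningModuleTest):
>>>     DISTRIBUTED = False  # skip the DistributedDataParallel variant
>>>
>>>     @pytest.fixture
>>>     def model(self):
>>>         return ...  # return your LightningModule here
>>>
>>>     @pytest.fixture
>>>     def data(self):
>>>         return torch.rand(2, 1, 10, 10)
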
test_val_dataloader(model)[source]

Tests that model.val_dataloader() runs and returns the required output.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_validation_epoch_end(model)[source]

Tests that validation_epoch_end() runs and outputs a dict as required by PyTorch Lightning.

Because of the size of some models, this test is only run when a GPU is available.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

test_validation_step(model)[source]

Runs a validation step based on the data returned from model.val_dataloader(). Tests that the dictionary returned from validation_step() is as required by PyTorch Lightning.

Because of the size of some models, this test is only run when a GPU is available.

Parameters

model (pytorch_lightning.core.lightning.LightningModule) –

Mixins

class combustion.testing.TorchScriptTestMixin[source]

Mixin to test a torch.nn.Module’s ability to be scripted using torch.jit.script(), saved to disk, and loaded.

The following fixtures should be implemented in the subclass:

  • model() - returns the model to be tested

test_load_scripted(model, tmp_path)[source]

Tests that a torch.jit.ScriptModule saved to disk using torch.jit.script() can be loaded, and that the loaded object is a torch.jit.ScriptModule.

test_save_scripted(model, tmp_path)[source]

Calls torch.jit.script() on the given model and tests that the resultant torch.jit.ScriptModule can be saved to disk using torch.jit.save().

test_script(model)[source]

Calls torch.jit.script() on the given model and tests that a torch.jit.ScriptModule is returned.
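
A minimal usage sketch (the test class name is hypothetical)::
>>> import pytest
>>> from combustion.testing import TorchScriptTestMixin
>>> class TestMyModelScript(TorchScriptTestMixin):
>>>     @pytest.fixture
>>>     def model(self):
>>>         return ...  # return the torch.nn.Module to be scripted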

class combustion.testing.TorchScriptTraceTestMixin[source]

Mixin to test a torch.nn.Module’s ability to be traced using torch.jit.trace(), saved to disk, and loaded.

The following fixtures should be implemented in the subclass:

  • model() - returns the model to be tested

  • data() - returns an input to model.forward().

test_load_traced(model, tmp_path, data)[source]

Tests that a torch.jit.ScriptModule saved to disk using torch.jit.trace() can be loaded, and that the loaded object is a torch.jit.ScriptModule.

test_save_traced(model, tmp_path, data)[source]

Calls torch.jit.trace() on the given model and tests that the resultant torch.jit.ScriptModule can be saved to disk using torch.jit.save().

test_trace(model, data)[source]

Calls torch.jit.trace() on the given model and tests that a torch.jit.ScriptModule is returned.

test_traced_forward_call(model, data)[source]

Calls torch.jit.trace() on the given model and tests that invoking the traced module's forward() on the provided data succeeds.

Because of the size of some models, this test is only run when a GPU is available.
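
A minimal usage sketch (the test class name is hypothetical)::
>>> import pytest
>>> import torch
>>> from combustion.testing import TorchScriptTraceTestMixin
>>> class TestMyModelTrace(TorchScriptTraceTestMixin):
>>>     @pytest.fixture
>>>     def model(self):
>>>         return ...  # return the torch.nn.Module to be traced
>>>
>>>     @pytest.fixture
>>>     def data(self):
>>>         return torch.rand(2, 1, 10, 10)  # example input to model.forward()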
