Shortcuts

Source code for combustion.testing.assertions

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch import Tensor
from torch.testing import assert_allclose


def assert_has_gradient(module: nn.Module, recurse: bool = True):
    r"""Asserts that gradients have been populated.

    * If ``module`` is a :class:`torch.Tensor`, its ``.grad`` must not be ``None``.
    * If ``module`` is a :class:`torch.nn.Module`, every parameter with
      ``requires_grad=True`` must have a non-``None`` ``.grad``. Parameters that
      do not require gradients are skipped.
    * Any other object is accepted without checks.

    Args:
        module (torch.nn.Module or torch.Tensor):
            The module (or tensor) to inspect.

        recurse (bool, optional):
            Whether parameters of child modules are inspected as well.
            Ignored when ``module`` is a tensor.

    Raises:
        AssertionError: If a required gradient is missing.
    """
    __tracebackhide__ = True

    if isinstance(module, torch.Tensor):
        if module.grad is None:
            raise AssertionError(f"tensor grad == {module.grad}")
        return

    if not isinstance(module, torch.nn.Module):
        return

    # First trainable parameter lacking a gradient, or None if all are populated.
    offender = next(
        (
            (pname, p)
            for pname, p in module.named_parameters(recurse=recurse)
            if p.requires_grad and p.grad is None
        ),
        None,
    )
    if offender is not None:
        name, param = offender
        raise AssertionError(f"param {name} grad == {param.grad}")
def assert_zero_grad(module: nn.Module, recurse: bool = True):
    r"""Asserts that the gradients of a module (or tensor) are all zero.

    Useful for checking whether :meth:`torch.optim.Optimizer.zero_grad` was called.
    A gradient of ``None`` is treated as zero, because ``zero_grad(set_to_none=True)``
    resets gradients to ``None`` instead of filling them with zeros.

    Args:
        module (torch.nn.Module or torch.Tensor):
            The module (or tensor) to inspect.

        recurse (bool, optional):
            Whether parameters of child modules are inspected as well.
            Ignored when ``module`` is a tensor.

    Raises:
        AssertionError: If an inspected gradient contains a nonzero element.
    """
    __tracebackhide__ = True

    def _is_zero(grad) -> bool:
        # ``None`` means nothing has been accumulated, i.e. effectively zero.
        # ``.bool().any()`` reduces over every element regardless of shape
        # (0-d, 1-d, n-d). NaN maps to True, so a NaN gradient counts as nonzero.
        return grad is None or not bool(grad.bool().any())

    if isinstance(module, torch.Tensor):
        if not _is_zero(module.grad):
            raise AssertionError(f"module.grad == {module.grad}")
    elif isinstance(module, torch.nn.Module):
        for name, param in module.named_parameters(recurse=recurse):
            if param.requires_grad and not _is_zero(param.grad):
                raise AssertionError(f"param {name} grad == {param.grad}")
def assert_in_training_mode(module: nn.Module):
    r"""Asserts that ``module`` is in training mode.

    Passes when ``module.training`` is truthy, e.g. after ``module.train()``.

    Args:
        module (torch.nn.Module):
            The module to inspect.

    Raises:
        AssertionError: If ``module.training`` is falsy.
    """
    __tracebackhide__ = True
    is_training = module.training
    if is_training:
        return
    raise AssertionError(f"module.training == {module.training}")
def assert_in_eval_mode(module: nn.Module):
    r"""Asserts that ``module`` is in inference (eval) mode.

    Passes when ``module.training`` is falsy, e.g. after ``module.eval()``.

    Args:
        module (torch.nn.Module):
            The module to inspect.

    Raises:
        AssertionError: If ``module.training`` is truthy.
    """
    __tracebackhide__ = True
    is_training = module.training
    if not is_training:
        return
    raise AssertionError(f"module.training == {module.training}")
def assert_tensors_close(x: Tensor, y: Tensor, *args, **kwargs):
    r"""Asserts that the values of two tensors are close.

    Wraps :func:`torch.testing.assert_allclose`. Any failure is re-raised as a
    fresh :class:`AssertionError` carrying the original message, so pytest shows
    a single concise error instead of a chained internal traceback.

    Args:
        x (torch.Tensor):
            The first (actual) tensor.

        y (torch.Tensor):
            The second (expected) tensor.

        *args, **kwargs:
            Forwarded unchanged to :func:`torch.testing.assert_allclose`.

    Raises:
        AssertionError: If the tensors are not close.
    """
    __tracebackhide__ = True
    try:
        assert_allclose(x, y, *args, **kwargs)
    except AssertionError as e:
        # ``from None`` suppresses implicit exception chaining. Without it the
        # wrapped traceback is printed as well, defeating the cleaner output.
        raise AssertionError(str(e)) from None
def assert_is_int_tensor(x: Tensor):
    r"""Asserts that every value of a tensor is an integer.

    For floating point tensors this is equivalent to
    ``torch.allclose(x, x.round())``. NaN is therefore rejected, because
    ``allclose`` does not treat NaN as equal to itself. Integer and boolean
    dtypes can only hold integers and always pass.

    Args:
        x (torch.Tensor):
            The tensor to inspect.

    Raises:
        AssertionError: If any value of ``x`` is not (close to) an integer.
    """
    __tracebackhide__ = True

    # Integral/bool dtypes cannot store fractional values; nothing to check.
    if not (x.is_floating_point() or x.is_complex()):
        return

    rounded = x.round()
    if not torch.allclose(x, rounded):
        # Raise explicitly rather than via a bare ``assert``, which is stripped
        # under ``python -O`` and carries no message without pytest rewriting.
        raise AssertionError(f"tensor is not integer-valued:\n{x}\nnearest integers:\n{rounded}")
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "assert_has_gradient",
    "assert_zero_grad",
    "assert_is_int_tensor",
    "assert_in_training_mode",
    "assert_in_eval_mode",
    "assert_tensors_close",
]

© Copyright 2020, Scott Chase Waggener. Revision 6d81d6b9.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.1.0rc2
Versions
latest
docs
0.1.0rc2
v0.1.0rc1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources