
Source code for combustion.nn.activations.swish

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

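# Implementation inspired by an external memory-efficient Swish whose link was lost
# in extraction (presumably the one in lukemelas/EfficientNet-PyTorch; TODO confirm and restore URL).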

class _SwishFunction(torch.autograd.Function):
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        return result

    def backward(ctx, grad_output):
        for i in ctx.saved_tensors:
            sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

def swish(inputs: Tensor, memory_efficient: bool = True) -> Tensor:
    r"""Apply the swish activation element-wise.

    .. math::
        f(x) = x \cdot \text{sigmoid}(x)

    Args:
        inputs (Tensor): Tensor to activate.
        memory_efficient (bool, optional): If ``True`` (the default), route the
            computation through a custom autograd function that uses less memory
            at training time. That path cannot be used with TorchScript.

    Returns:
        Tensor: ``inputs * sigmoid(inputs)``.

    .. warning::
        With ``memory_efficient=False`` this function can be traced with TorchScript.
        It cannot be scripted, because the memory-efficient path relies on
        :class:`torch.autograd.Function`. Export via :func:`torch.jit.trace` and pass
        ``memory_efficient=False``.
    """
    if not memory_efficient:
        gate = torch.sigmoid(inputs)
        return inputs * gate
    return _SwishFunction.apply(inputs)
class Swish(nn.Module):
    r"""The swish activation function, defined as

    .. math::
        f(x) = x \cdot \text{sigmoid}(x)

    In training mode the memory-efficient autograd implementation is used. In
    eval mode the plain expression ``inputs * sigmoid(inputs)`` is used, which
    keeps the module traceable.

    .. warning::
        This method is traceable with TorchScript, but is un-scriptable due to the
        use of :class:`torch.autograd.Function` for a memory-efficient backward pass.
        Please export using :func:`torch.jit.trace` after calling ``module.eval()``.
    """

    @torch.jit.ignore
    def _memory_efficient_forward(self, inputs: Tensor) -> Tensor:
        """Run the autograd-Function-backed swish; skipped by the TorchScript compiler."""
        return swish(inputs)

    def forward(self, inputs: Tensor) -> Tensor:
        """Apply swish to ``inputs``, choosing the implementation by training mode.

        Args:
            inputs (Tensor): Tensor to activate.

        Returns:
            Tensor: ``inputs * sigmoid(inputs)``.
        """
        # NOTE(review): the original branch condition was lost (the source read
        # `if not return ...`, a syntax error). It is reconstructed from the class
        # docstring: exporting requires `module.eval()`, so eval mode must take the
        # plain, traceable expression. Confirm against upstream.
        if self.training:
            return self._memory_efficient_forward(inputs)
        return inputs * torch.sigmoid(inputs)
def hard_swish(inputs: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the hard swish activation from `Searching For MobileNetV3`_.

    .. math::
        f(x) = x \cdot \frac{\text{ReLU6}(x + 3)}{6}

    Hard swish approximates swish but is computationally cheaper, since it avoids
    evaluating :math:`\text{sigmoid}(x)`.

    Args:
        inputs (Tensor): Tensor to activate.
        inplace (bool, optional): If ``True``, write the result into ``inputs``
            and return it instead of allocating a new output tensor.

    Returns:
        Tensor: The activated tensor (``inputs`` itself when ``inplace=True``).

    .. _Searching for MobileNetV3: https://arxiv.org/abs/1905.02244
    """
    if not inplace:
        scale = F.relu6(inputs + 3).div(6)
        return scale.mul(inputs)
    # `inputs + 3` is a fresh temporary, so it can be clamped and scaled in place
    # before being multiplied into `inputs`.
    scale = F.relu6(inputs + 3, inplace=True).div_(6)
    return inputs.mul_(scale)
class HardSwish(nn.Module):
    r"""Module form of the hard swish activation from `Searching For MobileNetV3`_.

    .. math::
        f(x) = x \cdot \frac{\text{ReLU6}(x + 3)}{6}

    Hard swish approximates swish but is computationally cheaper, since it avoids
    evaluating :math:`\text{sigmoid}(x)`.

    .. image:: ./hswish.png
        :width: 600px
        :align: center
        :height: 300px
        :alt: Comparison of Hard Swish and Swish activations.

    Args:
        inplace (bool, optional): If ``True``, modify the input tensor in place.

    .. _Searching for MobileNetV3: https://arxiv.org/abs/1905.02244
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Forwarded to hard_swish on every call.
        self.inplace = inplace

    def extra_repr(self):
        # Only mention the flag when it deviates from the default.
        return "inplace=True" if self.inplace else ""

    def forward(self, inputs: Tensor) -> Tensor:
        return hard_swish(inputs, inplace=self.inplace)

© Copyright 2020, Scott Chase Waggener. Revision 6d81d6b9.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.1.0rc2
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources