Shortcuts

Source code for combustion.nn.modules.raspp

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import List, Tuple, Union

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# from combustion.nn import HardSigmoid
from combustion.util import double, single, triple


class _RASPPMeta(type):
    def __new__(cls, name, bases, dct):
        x = super().__new__(cls, name, bases, dct)
        if "3d" in name:
            x.Conv = nn.Conv3d
            x.BatchNorm = nn.BatchNorm3d
            x.AvgPool = nn.AvgPool3d
            x.Tuple = staticmethod(triple)
        elif "2d" in name:
            x.Conv = nn.Conv2d
            x.BatchNorm = nn.BatchNorm2d
            x.AvgPool = nn.AvgPool2d
            x.Tuple = staticmethod(double)
        elif "1d" in name:
            x.Conv = nn.Conv1d
            x.BatchNorm = nn.BatchNorm1d
            x.AvgPool = nn.AvgPool1d
            x.Tuple = staticmethod(single)
        else:
            raise RuntimeError(f"Metaclass: error processing name {cls.__name__}")
        return x


class _RASPPLite(nn.Module):
    def __init__(
        self,
        input_filters: int,
        residual_filters: int,
        output_filters: int,
        num_classes: int,
        pool_kernel: Union[int, Tuple[int, ...]] = 42,
        pool_stride: Union[int, Tuple[int, ...]] = 18,
        dilation: Union[int, Tuple[int, ...]] = 1,
        sigmoid: nn.Module = nn.Sigmoid(),
        relu: nn.Module = nn.ReLU(),
        bn_momentum: float = 0.1,
        bn_epsilon: float = 1e-5,
        final_upsample: int = 1,
    ):
        super().__init__()
        pool_kernel = self.Tuple(pool_kernel)
        pool_stride = self.Tuple(pool_stride)
        dilation = self.Tuple(dilation)

        self.pooled = nn.Sequential(
            self.AvgPool(pool_kernel, stride=pool_stride),
            self.Conv(input_filters, output_filters, kernel_size=1, stride=pool_stride),
            sigmoid,
        )

        self.main_conv1 = nn.Sequential(
            self.Conv(input_filters, output_filters, kernel_size=1, bias=False),
            self.BatchNorm(output_filters, momentum=bn_momentum, eps=bn_epsilon),
            relu,
        )

        self.residual_conv = self.Conv(residual_filters, num_classes, kernel_size=1)
        self.main_conv2 = self.Conv(output_filters, num_classes, kernel_size=1)

        upsample = []
        for i in range(final_upsample // 2):
            upsample.append(nn.ConvTranspose2d(num_classes, num_classes, kernel_size=2, stride=2))

        if upsample:
            self.final_upsample = nn.ModuleList(*upsample)
        else:
            self.final_upsample = None

    def forward(self, inputs: List[Tensor]) -> Tensor:
        skip_path, main_path = inputs
        residual = self.residual_conv(skip_path)

        pooled = self.pooled(main_path)
        main = self.main_conv1(main_path)

        # upsample for main to match pooled
        upsample_shape: List[int] = []
        for i, size in enumerate(main.shape):
            if i >= 2:
                upsample_shape.append(size)
        pooled = F.interpolate(pooled, size=upsample_shape, mode="bilinear")

        main = main * pooled

        # upsample for main to match residual
        upsample_shape: List[int] = []
        for i, size in enumerate(residual.shape):
            if i >= 2:
                upsample_shape.append(size)
        main = F.interpolate(main, size=upsample_shape, mode="bilinear")

        main = self.main_conv2(main)
        output = main + residual

        if self.final_upsample is not None:
            output = self.final_upsample(output)

        return output


[docs]class RASPPLite2d(_RASPPLite, metaclass=_RASPPMeta): r"""Implements the a lite version of the reduced atrous spatial pyramid pooling module (R-ASPP Lite) described in `Searching for MobileNetV3`_. This is a semantic segmentation head. .. image:: ./raspp.png :width: 800px :align: center :height: 300px :alt: Diagram of R-ASPP Lite. Args: input_filters (int): Number of input channels along the main pathway residual_filters (int): Number of input channels along the residual pathway output_filters (int): Number of channels in the middle of the segmentation head. num_classes (int): Number of classes for semantic segmentation pool_kernel (int or tuple of ints): Size of the average pooling kernel pool_stride (int or tuple of ints): Stride of the average pooling kernel dilation (int or tuple of ints): Dilation of the atrous convolution. Defaults to ``1``, meaning no atrous convolution. sigmoid (:class:`torch.nn.Module`): Activation function to use along the pooled pathway relu (:class:`torch.nn.Module`): Activation function to use along the main convolutional pathway bn_momentum (float): Batch norm momentum bn_epsilon (float): Batch norm epsilon final_upsample (int): An optional amount of additional to be applied via transposed convolutions. It is expected that additional upsampling is a power of two. .. _Searching for MobileNetV3: https://arxiv.org/abs/1905.02244 """
class RASPPLite1d(_RASPPLite, metaclass=_RASPPMeta): pass class RASPPLite3d(_RASPPLite, metaclass=_RASPPMeta): pass

© Copyright 2020, Scott Chase Waggener. Revision 6d81d6b9.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.1.0rc2
Versions
latest
docs
0.1.0rc2
v0.1.0rc1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources