Source code for combustion.models.mobile_unet

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from copy import deepcopy
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from combustion.nn import MatchShapes, MobileNetBlockConfig


class _MobileUnetMeta(type):
    def __new__(cls, name, bases, dct):
        x = super().__new__(cls, name, bases, dct)
        if "3d" in name:
            x.Conv = nn.Conv3d
            x.ConvTranspose = nn.ConvTranspose3d
            x.BatchNorm = nn.BatchNorm3d
            x._get_blocks = MobileNetBlockConfig.get_3d_blocks
        elif "2d" in name:
            x.Conv = nn.Conv2d
            x.ConvTranspose = nn.ConvTranspose2d
            x.BatchNorm = nn.BatchNorm2d
            x._get_blocks = MobileNetBlockConfig.get_2d_blocks
        elif "1d" in name:
            x.Conv = nn.Conv1d
            x.ConvTranspose = nn.ConvTranspose1d
            x.BatchNorm = nn.BatchNorm1d
            x._get_blocks = MobileNetBlockConfig.get_1d_blocks
        else:
            raise RuntimeError(f"Metaclass: error processing name {cls.__name__}")
        return x

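# Illustration: the metaclass keys off the *name* of the class being created, so a
# subclass whose name contains "1d", "2d" or "3d" picks up the matching torch
# primitives, e.g.
#
#   >>> MobileUnet2d.Conv is nn.Conv2d
#   True
#   >>> MobileUnet3d.BatchNorm is nn.BatchNorm3d
#   True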

class _MobileUnet(nn.Module):
    def __init__(
        self,
        down_configs: List[MobileNetBlockConfig],
        up_configs: Optional[List[MobileNetBlockConfig]] = None,
        stem: Optional[nn.Module] = None,
        head: Optional[nn.Module] = None,
    ):
        super().__init__()
        down_configs = deepcopy(down_configs)
        if up_configs is not None:
            up_configs = deepcopy(up_configs)
        else:
            up_configs = []
            for config in reversed(down_configs):
                reversed_config = deepcopy(config)
                reversed_config.input_filters = config.output_filters
                reversed_config.output_filters = config.input_filters
                up_configs.append(reversed_config)

        # stem / head if provided
        self.stem = stem if stem is not None else nn.Identity()
        self.head = head if head is not None else nn.Identity()

        # MobileNetV3 convolution blocks for downsampling
        blocks = []
        for config in down_configs:
            conv_block = self.__class__._get_blocks(config)
            blocks.append(conv_block)
        self.down_blocks = nn.ModuleList(blocks)

        # MobileNetV3 convolution blocks for upsampling
        blocks = []
        for i, config in enumerate(up_configs):
            stride = config.stride

            in_channels = config.input_filters
            out_channels = config.output_filters
            config.output_filters = in_channels

            # special case for first up level with no skip conn
            if i != 0:
                config.input_filters *= 2

            config.stride = 1
            conv_block = nn.Sequential(
                self.__class__._get_blocks(config),
                self.ConvTranspose(in_channels, out_channels, kernel_size=2, stride=stride),
            )
            blocks.append(conv_block)
        self.up_blocks = nn.ModuleList(blocks)
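        # Example of the resulting channel bookkeeping (default up_configs): for a
        # down path 8 -> 16 -> 32 -> 64, the up blocks see 64, 32 + 32 and 16 + 16
        # input channels (after skip concatenation) and emit 32, 16 and 8 channels;
        # the final concatenation with the stem output then feeds 8 + 8 = 16 channels
        # to the head.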

        self.match_shapes = MatchShapes(strategy="crop")

    def forward(self, inputs: Tensor) -> Tensor:
        _ = self.stem(inputs)

        # downsampling levels
        skip_conns: List[Tensor] = [_]
        for down_level in self.down_blocks:
            _ = down_level(_)
            skip_conns.append(_)
        # the deepest feature map is the current tensor, not a skip connection
        del skip_conns[-1]

        for i, up_level in enumerate(self.up_blocks):
            # upsample
            _ = up_level(_)

            # match skip conn shape and cat
            skip_conn = skip_conns[-(i + 1)]
            spatial_shape = skip_conn.shape[2:]
            _ = self.match_shapes([_], spatial_shape)[0]
            _ = torch.cat([_, skip_conn], dim=1)

        _ = self.head(_)
        return _

    @classmethod
    def from_identical_blocks(
        cls,
        block: MobileNetBlockConfig,
        in_channels: int,
        levels: List[int],
        stem: Optional[nn.Module] = None,
        head: Optional[nn.Module] = None,
    ) -> "_MobileUnet":
        down_blocks: List[MobileNetBlockConfig] = []
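        # each level repeats ``block`` ``i`` times, doubles the channel count, and
        # downsamples by a factor of 2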
        for i in levels:
            new_block = deepcopy(block)
            new_block.num_repeats = i
            new_block.input_filters = in_channels
            new_block.output_filters = in_channels * 2
            new_block.stride = 2
            down_blocks.append(new_block)
            in_channels *= 2

        return cls(down_blocks, stem=stem, head=head)


class MobileUnet1d(_MobileUnet, metaclass=_MobileUnetMeta):
    pass


class MobileUnet2d(_MobileUnet, metaclass=_MobileUnetMeta):
    r"""Modified implementation of U-Net as described in the `U-Net paper`_. This
    implementation uses MobileNetV3 inverted bottleneck blocks (from `Searching for
    MobileNetV3`_) as the fundamental building block of each convolutional layer.
    Automatic padding/cropping is used to ensure operation with an input of arbitrary
    spatial shape.

    A general U-Net architecture is as follows:

    .. image:: unet.png
        :width: 800px
        :align: center
        :height: 500px
        :alt: Diagram of the U-Net architecture

    Args:
        down_configs (list of :class:`combustion.nn.MobileNetBlockConfig`):
            Configs for each of the :class:`combustion.nn.MobileNetConvBlock2d` blocks
            used in the downsampling portion of the model.

        up_configs (optional, list of :class:`combustion.nn.MobileNetBlockConfig`):
            Configs for each of the :class:`combustion.nn.MobileNetConvBlock2d` blocks
            used in the upsampling portion of the model. By default, the reverse of
            ``down_configs`` is used, with as-needed modifications.

        stem (optional, :class:`torch.nn.Module`):
            An optional stem/tail layer.

        head (optional, :class:`torch.nn.Module`):
            An optional head layer.

    Shapes
        * Input: :math:`(N, C, H, W)`
        * Output: :math:`(N, C', H', W')`, where the number of channels and the spatial
          dimensions depend on the chosen blocks and head.

    .. _U-Net paper: https://arxiv.org/abs/1505.04597
    .. _Searching for MobileNetV3: https://arxiv.org/abs/1905.02244
    """


class MobileUnet3d(_MobileUnet, metaclass=_MobileUnetMeta):
    pass
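
# Hedged usage sketch. ``from_identical_blocks`` builds the downsampling path from a
# single template config; the keyword arguments passed to ``MobileNetBlockConfig``
# below are assumptions about its constructor and should be checked against the
# actual signature:
#
#   >>> block = MobileNetBlockConfig(input_filters=8, output_filters=8, kernel_size=3)
#   >>> model = MobileUnet2d.from_identical_blocks(block, in_channels=8, levels=[1, 2, 3])
#   >>> out = model(torch.rand(1, 8, 64, 64))
#
# With the default ``nn.Identity`` head, ``out`` should have 16 channels here: the 8
# channels emitted by the last up block concatenated with the 8-channel stem output.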
