
Source code for combustion.nn.modules.bifpn

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from combustion.util import double, single, triple

class _BiFPNMeta(type):
    """Metaclass that attaches dimension-specific layer types to a BiFPN class.

    The dimensionality is picked from the *name* of the class being created.
    The markers ``"3d"``, ``"2d"`` and ``"1d"`` are checked in that order and
    the first one contained in the name wins. The created class receives four
    class attributes:

    * ``Conv`` -- the matching ``torch.nn.ConvNd`` class
    * ``BatchNorm`` -- the matching ``torch.nn.BatchNormNd`` class
    * ``MaxPool`` -- the matching ``torch.nn.MaxPoolNd`` class
    * ``Tuple`` -- the ``single`` / ``double`` / ``triple`` helper, wrapped in
      ``staticmethod`` so that ``self.Tuple(x)`` does not pass ``self``

    Raises:
        RuntimeError: If the class name contains none of the markers.
    """

    def __new__(cls, name, bases, dct):
        x = super().__new__(cls, name, bases, dct)
        if "3d" in name:
            x.Conv = nn.Conv3d
            x.BatchNorm = nn.BatchNorm3d
            x.MaxPool = nn.MaxPool3d
            x.Tuple = staticmethod(triple)
        elif "2d" in name:
            x.Conv = nn.Conv2d
            x.BatchNorm = nn.BatchNorm2d
            x.MaxPool = nn.MaxPool2d
            x.Tuple = staticmethod(double)
        elif "1d" in name:
            x.Conv = nn.Conv1d
            x.BatchNorm = nn.BatchNorm1d
            x.MaxPool = nn.MaxPool1d
            x.Tuple = staticmethod(single)
        else:
            # Report the class being created; ``cls`` is the metaclass itself
            # and would always print the same, uninformative name.
            raise RuntimeError(f"Metaclass: error processing name {name}")
        return x

class _BiFPN_Level(nn.Module):
    """One level of a BiFPN block.

    A level fuses its own feature map with the neighbouring levels using
    learnable, non-negative weights. Each weight vector is passed through
    ReLU and divided by ``sum(weights) + epsilon`` before use.

    Args:
        num_channels: Value forwarded to ``conv`` when building the two
            convolution stacks.
        conv: Factory called as ``conv(num_channels)``; must return the module
            applied after each fusion step. Required despite the ``None``
            default.
        pool: Factory called as ``pool(kernel_size=scale_factor)``; must return
            the module used to downscale ``previous_level``. Required despite
            the ``None`` default.
        epsilon: Added to the weight sum during normalization.
        scale_factor: Passed to both ``nn.Upsample`` (nearest mode) and
            ``pool``.
        weight_2_count: Number of learnable weights in the second fusion step.

    Raises:
        TypeError: If ``conv`` or ``pool`` is ``None``.
    """

    __constants__ = ["epsilon"]

    def __init__(
        self,
        num_channels: int,
        conv: Optional[Callable[[int], nn.Module]] = None,
        pool: Optional[Callable[[int], nn.Module]] = None,
        epsilon: float = 1e-4,
        scale_factor: Union[int, Tuple[int, ...]] = 2,
        weight_2_count: int = 3,
    ):
        super(_BiFPN_Level, self).__init__()
        if conv is None or pool is None:
            # Fail with a readable message instead of "'NoneType' object is not callable".
            raise TypeError(f"conv and pool must both be callables, found conv={conv!r}, pool={pool!r}")

        self.epsilon = float(epsilon)

        # Conv layers (each built exactly once)
        self.conv_up = conv(num_channels)
        self.conv_down = conv(num_channels)

        # Feature scaling layers
        self.feature_up = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.feature_down = pool(kernel_size=scale_factor)

        # Learnable fusion weights, initialised to 1
        self.weight_1 = nn.Parameter(torch.ones(2))
        self.weight_2 = nn.Parameter(torch.ones(weight_2_count))

    def forward(
        self, same_level: Tensor, previous_level: Optional[Tensor] = None, next_level: Optional[Tensor] = None
    ) -> Tensor:
        """Fuse ``same_level`` with whichever neighbouring levels are given.

        Args:
            same_level: Feature map of this level.
            previous_level: Feature map that is passed through ``feature_down``
                before fusion, or ``None``.
            next_level: Feature map that is passed through ``feature_up``
                before fusion, or ``None``.

        Returns:
            The fused feature map.

        Raises:
            ValueError: If both ``previous_level`` and ``next_level`` are ``None``.
        """
        output: Tensor = same_level

        if previous_level is None and next_level is None:
            raise ValueError("previous_level and next_level cannot both be None")

        # input + higher level
        if next_level is not None:
            weight_1 = torch.relu(self.weight_1)
            weight_1 = weight_1 / (torch.sum(weight_1, dim=0) + self.epsilon)

            # weighted combination of current level and higher level
            output = self.conv_up(weight_1[0] * same_level + weight_1[1] * self.feature_up(next_level))

        # input + lower level + last bifpn level (if one exists)
        if previous_level is not None:
            weight_2 = torch.relu(self.weight_2)
            weight_2 = weight_2 / (torch.sum(weight_2, dim=0) + self.epsilon)

            # NOTE(review): ``output`` is initialised to ``same_level`` above, so
            # this test is always true and the two-input branch below is never
            # taken. The intent was probably ``next_level is not None``; it is
            # left as-is because changing it would alter the output of existing
            # models -- confirm before changing.
            if output is not None:
                # weighted combination of current level, downward fpn output, lower level
                output = self.conv_down(
                    weight_2[0] * same_level + weight_2[1] * output + weight_2[2] * self.feature_down(previous_level)
                )
            # special case for top of pyramid
            else:
                # weighted combination of current level and lower level
                output = self.conv_down(weight_2[0] * same_level + weight_2[1] * self.feature_down(previous_level))

        return output

class _BiFPN(nn.Module):
    """Dimension-agnostic BiFPN implementation.

    Concrete subclasses are expected to provide the class attributes ``Conv``,
    ``BatchNorm``, ``MaxPool`` and ``Tuple`` (see ``_BiFPNMeta``); this class
    reads them through ``self``.

    Args:
        num_channels: Number of channels used by every convolution and batch
            norm layer. Must be ``>= 1``.
        levels: Number of ``_BiFPN_Level`` modules to create. Must be ``>= 1``.
        kernel_size: Kernel size of the depthwise convolution. Padding is
            ``(kernel - 1) // 2`` per dimension.
        stride: Scale factor between adjacent levels, forwarded to each
            ``_BiFPN_Level`` as ``scale_factor``.
        epsilon: Stabiliser for the weight normalization. Must be ``> 0``.
        bn_momentum: ``momentum`` for the batch norm layers.
        bn_epsilon: ``eps`` for the batch norm layers.
        activation: Module appended to every convolution stack. The same
            instance is reused in all stacks, and the default instance is
            created once at definition time.

    Raises:
        ValueError: If ``epsilon``, ``num_channels`` or ``levels`` is out of range.
    """

    __constants__ = ["levels"]

    def __init__(
        self,
        num_channels: int,
        levels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 2,
        epsilon: float = 1e-4,
        bn_momentum: float = 0.9997,
        bn_epsilon: float = 4e-5,
        activation: nn.Module = torch.nn.ReLU(),
    ):
        # nn.Module must be initialised before any submodule is assigned.
        super(_BiFPN, self).__init__()
        if float(epsilon) <= 0.0:
            raise ValueError(f"epsilon must be float > 0, found {epsilon}")
        if int(num_channels) < 1:
            raise ValueError(f"num_channels must be int > 0, found {num_channels}")
        if int(levels) < 1:
            raise ValueError(f"levels must be int > 0, found {levels}")

        self.levels = levels
        kernel_size = self.Tuple(kernel_size)
        stride = self.Tuple(stride)
        padding = tuple((kernel - 1) // 2 for kernel in kernel_size)

        def conv(num_channels):
            # depthwise conv -> pointwise conv -> batch norm -> activation
            return nn.Sequential(
                self.Conv(num_channels, num_channels, kernel_size, padding=padding, groups=num_channels, bias=False),
                self.Conv(num_channels, num_channels, kernel_size=1, bias=False),
                self.BatchNorm(num_features=num_channels, momentum=bn_momentum, eps=bn_epsilon),
                activation,
            )

        level_modules = []
        for i in range(levels):
            # The first level is created with two second-stage weights, all others with three.
            weight_2_count = 3 if i > 0 else 2
            level_modules.append(_BiFPN_Level(num_channels, conv, self.MaxPool, epsilon, stride, weight_2_count))
        self.bifpn = nn.ModuleList(level_modules)

    def forward(self, inputs: List[Tensor]) -> List[Tensor]:
        """Run every level on its input and the two neighbouring inputs.

        Level ``i`` receives ``inputs[i]`` together with ``inputs[i - 1]`` and
        ``inputs[i + 1]`` where those exist (``None`` otherwise). Every level
        reads from ``inputs``, not from the outputs of other levels.

        Args:
            inputs: One tensor per level.

        Returns:
            One output tensor per level, in the same order.
        """
        outputs: List[Tensor] = []

        for i, layer in enumerate(self.bifpn):
            current_level = inputs[i]
            previous_level = inputs[i - 1] if i > 0 else None
            next_level = inputs[i + 1] if i < len(inputs) - 1 else None
            outputs.append(layer(current_level, previous_level, next_level))

        return outputs

class BiFPN1d(_BiFPN, metaclass=_BiFPNMeta):
    r"""One-dimensional variant of the bi-directional feature pyramid network.

    The ``1d`` marker in the class name makes ``_BiFPNMeta`` bind
    :class:`torch.nn.Conv1d`, :class:`torch.nn.BatchNorm1d` and
    :class:`torch.nn.MaxPool1d` as the layer types. Constructor arguments are
    those of ``_BiFPN``.
    """

class BiFPN2d(_BiFPN, metaclass=_BiFPNMeta):
    r"""A bi-directional feature pyramid network (BiFPN) used in the EfficientDet implementation
    (`EfficientDet Scalable and Efficient Object Detection`_). The bi-directional FPN mixes features
    at different resolutions, while also capturing (via learnable weights) that features at different
    resolutions can contribute unequally to the desired output. Weights controlling the contribution
    of each FPN level are normalized using fast normalized fusion, which the authors note is more
    efficient than a softmax based fusion. It is ensured that for all weights, :math:`w_i \geq 0` by
    applying ReLU to each weight.

    The weight normalization is as follows

    .. math::
        O = \sum_{i}\frac{w_i}{\epsilon + \sum_{j} w_j} \cdot I_i

    The structure of the block is as follows:

    .. NOTE(review): the image location below lost its URL prefix; restore the full URL from the
       upstream repository.

    .. image:: *qH6d0kBU2cRxOkWUsfgDgg.png
        :width: 300px
        :align: center
        :height: 400px
        :alt: Diagram of BiFPN layer

    Args:
        num_channels (int):
            The number of channels in each feature pyramid level. All inputs :math:`P_i` should
            have ``num_channels`` channels, and outputs :math:`P_i'` will have ``num_channels`` channels.

        levels (int):
            The number of levels in the feature pyramid. Must have ``levels > 1``.

        kernel_size (int or tuple of ints):
            Choice of kernel size

        stride (int or tuple of ints):
            Controls the scaling used to upsample/downsample adjacent levels in the BiFPN. This stride
            is passed to :class:`torch.nn.MaxPool2d` and :class:`torch.nn.Upsample`.

        epsilon (float, optional):
            Small value used for numerical stability when normalizing weights via fast normalized
            fusion. Default ``1e-4``.

        bn_momentum (float, optional):
            Momentum for batch norm layers.

        bn_epsilon (float, optional):
            Epsilon for batch norm layers.

        activation (:class:`torch.nn.Module`):
            Activation function to use on convolution layers.

    Shape:
        - Inputs: List of Tensors of shape :math:`(N, *C, *H, *W)` where :math:`*C, *H, *W` indicates
          variable channel/height/width at each level of downsampling.
        - Output: Same shape as input.

    .. _EfficientDet Scalable and Efficient Object Detection:
        https://arxiv.org/abs/1911.09070
    """
class BiFPN(nn.Module):
    """Deprecated wrapper around :class:`BiFPN2d`.

    All constructor arguments are forwarded to ``BiFPN2d`` and ``forward``
    delegates to the wrapped instance. Constructing this class emits a
    ``DeprecationWarning``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._bifpn = BiFPN2d(*args, **kwargs)
        import warnings

        # stacklevel=2 attributes the warning to the code that instantiated BiFPN.
        warnings.warn(
            "BiFPN is deprecated, please use BiFPN2d instead",
            category=DeprecationWarning,
            stacklevel=2,
        )

    def forward(self, inputs: List[Tensor]) -> List[Tensor]:
        """Return the wrapped ``BiFPN2d`` applied to ``inputs``."""
        return self._bifpn(inputs)


class BiFPN3d(_BiFPN, metaclass=_BiFPNMeta):
    r"""Three-dimensional variant of the bi-directional feature pyramid network.

    The ``3d`` marker in the class name makes ``_BiFPNMeta`` bind
    :class:`torch.nn.Conv3d`, :class:`torch.nn.BatchNorm3d` and
    :class:`torch.nn.MaxPool3d` as the layer types. Constructor arguments are
    those of ``_BiFPN``.
    """

© Copyright 2020, Scott Chase Waggener. Revision cac3fb98.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.1.0rc1
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources