Shortcuts

Source code for combustion.nn.loss.focal

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


# NOTE(review): this constant is not referenced by any definition visible in
# this file -- confirm whether it is still needed before relying on it.
_EPSILON = 1e-5


def focal_loss(
    input: Tensor,
    target: Tensor,
    gamma: float,
    pos_weight: Optional[float] = None,
    label_smoothing: Optional[float] = None,
    reduction: str = "mean",
    normalize: bool = False,
):
    r"""Computes the Focal Loss between input and target.

    See :class:`FocalLoss` for more details.

    The element-wise loss is ``alpha * (1 - pt) ** gamma * BCE(input, target)``
    where ``pt = input`` for positive elements and ``pt = 1 - input`` otherwise.
    Positive elements are those where the *unsmoothed* ``target`` equals 1.

    Args:
        input (torch.Tensor): The predicted values on the interval :math:`[0, 1]`.

        target (torch.Tensor): The target values on the interval :math:`[0, 1]`.
            The caller's tensor is not modified; a detached copy is used internally.

        gamma (float): The focusing parameter :math:`\gamma`. Must be non-negative.

        pos_weight (float, optional): The positive weight coefficient :math:`\alpha` to use on
            the positive examples. Non-positive examples are weighted by ``1 - pos_weight``.
            Must be non-negative.

        label_smoothing (float, optional): Float in [0, 1]. When 0, no smoothing occurs. When positive,
            the binary ground truth labels are clamped to :math:`[p, 1-p]`.

        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Any other
            value behaves like ``'none'``. Default: ``'mean'``

        normalize (bool, optional): If ``True``, the element-wise loss is divided by the number of
            positive elements in ``target`` (at least 1) before ``reduction`` is applied.

    Returns:
        torch.Tensor: A scalar for ``'mean'`` / ``'sum'``, otherwise a tensor shaped like ``input``.
    """
    # Capture the positive mask BEFORE label smoothing: once targets are clamped
    # to [p, 1 - p] no element equals 1 any more.
    positive_indices = target == 1

    with torch.no_grad():
        # Work on a private copy so the in-place clamp below never touches the
        # caller's tensor.
        target = target.clone().detach()

        if pos_weight is not None:
            alpha = torch.empty_like(input).fill_(1 - pos_weight)
            alpha[positive_indices] = pos_weight
        else:
            alpha = torch.ones_like(input)

        if label_smoothing:
            target.clamp_(label_smoothing, 1.0 - label_smoothing)

    # pt is the probability assigned to the true class. It must be selected with
    # the pre-smoothing mask; comparing the smoothed target with 1 would treat
    # every element as a negative and invert the focusing for positives.
    p = input
    pt = torch.where(positive_indices, p, 1 - p)
    ce_loss = F.binary_cross_entropy(input, target, reduction="none")
    loss = alpha * torch.pow(1 - pt, gamma) * ce_loss

    if normalize:
        # clamp avoids a division by zero when the batch has no positives
        num_positive_examples = positive_indices.sum().clamp_(min=1)
        loss = loss / num_positive_examples

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
def focal_loss_with_logits(
    input: Tensor,
    target: Tensor,
    gamma: float,
    pos_weight: Optional[float] = None,
    label_smoothing: Optional[float] = None,
    reduction: str = "mean",
    normalize: bool = False,
):
    r"""Computes the Focal Loss between input logits and target.

    See :class:`FocalLossWithLogits` for more details.

    Args:
        input (torch.Tensor): The predicted values, given as logits (no sigmoid applied).

        target (torch.Tensor): The target values.

        gamma (float): The focusing parameter :math:`\gamma`. Must be non-negative.

        pos_weight (float, optional): The positive weight coefficient :math:`\alpha` to use on
            the positive examples. Non-positive examples are weighted by ``1 - pos_weight``.
            Must be non-negative.

        label_smoothing (float, optional): Float in [0, 1]. When 0, no smoothing occurs. When positive,
            the binary ground truth labels are clamped to :math:`[p, 1-p]`.

        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Any other
            value behaves like ``'none'``. Default: ``'mean'``

        normalize (bool, optional): If ``True``, the element-wise loss is divided by the number of
            positive elements in ``target`` (at least 1) before ``reduction`` is applied.

    Returns:
        torch.Tensor: A scalar for ``'mean'`` / ``'sum'``, otherwise a tensor shaped like ``input``.
    """
    # Positives are identified before smoothing; they drive pos_weight and normalize.
    positive_indices = target == 1

    # apply label smoothing, clamping true labels between x, 1-x
    if label_smoothing is not None:
        target = target.clamp(label_smoothing, 1.0 - label_smoothing)

    # loss in paper can be expressed as
    #   alpha * (1 - pt) ** gamma * (BCE loss)
    # where pt = p if target == 1 else (1-p)
    #
    # NOTE BCE loss gets the raw logits, NOT p = sigmoid(input)
    ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction="none")

    # One unified form for positive (z = 1) and negative (z = 0) samples is:
    #   (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x)))
    # log(1 + exp(-x)) is evaluated as softplus(-x). The naive exp().log1p()
    # overflows to inf for strongly negative logits, which zeroes the focal
    # term (and yields NaN gradients) exactly for the hardest positives.
    neg_logits = input.neg()
    if gamma != 0:
        focal_term = torch.exp(gamma * (target * neg_logits - F.softplus(neg_logits)))
        loss = focal_term * ce_loss
    else:
        loss = ce_loss

    if pos_weight is not None:
        loss = torch.where(positive_indices, pos_weight * loss, (1.0 - pos_weight) * loss)

    if normalize:
        # clamp avoids a division by zero when the batch has no positives
        num_positive_examples = positive_indices.sum().clamp_(min=1)
        loss = loss / num_positive_examples

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
class _FocalLoss(nn.Module):
    """Shared implementation of the focal-loss criteria.

    Stores the loss hyper-parameters and forwards them, together with the
    input and target, to the functional implementation held in ``_loss``.
    Subclasses select the functional form by overriding ``_loss``.
    """

    # Functional loss invoked by ``forward``; overridden by subclasses.
    _loss = focal_loss

    def __init__(
        self,
        gamma,
        pos_weight=None,
        label_smoothing=None,
        reduction: str = "mean",
        normalize: bool = False,
    ):
        super().__init__()
        self.gamma = gamma
        # pos_weight is stored under the name ``alpha``
        self.alpha = pos_weight
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.normalize = normalize

    @classmethod
    def from_args(cls, args):
        """Builds the criterion from an argument namespace.

        Reads ``focal_gamma``, ``focal_alpha``, ``focal_smoothing`` and
        ``reduction`` from ``args``; ``normalize`` keeps its default.
        """
        return cls(args.focal_gamma, args.focal_alpha, args.focal_smoothing, args.reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Calculates the focal loss between input and target.

        :param input: The predicted outputs
        :type input: Tensor
        :param target: The target outputs
        :type target: Tensor
        :rtype: Tensor
        """
        # Look ``_loss`` up on the class, not the instance: it is a plain
        # function, and instance access would bind ``self`` as its first argument.
        loss_fn = type(self)._loss
        return loss_fn(
            input,
            target,
            self.gamma,
            self.alpha,
            self.label_smoothing,
            self.reduction,
            self.normalize,
        )
class FocalLoss(_FocalLoss):
    r"""Creates a criterion that measures the Focal Loss between the target and the output.
    Focal loss is described in the paper `Focal Loss For Dense Object Detection`_.

    Inputs are expected to be probabilities on the interval :math:`[0, 1]`.

    Args:
        gamma (float): The focusing parameter :math:`\gamma`. Must be non-negative.

        pos_weight (float, optional): The positive weight coefficient :math:`\alpha` to use on
            the positive examples. Non-positive examples are weighted by ``1 - pos_weight``.
            Must be non-negative.

        label_smoothing (float, optional): Float in [0, 1]. When 0, no smoothing occurs. When positive,
            the binary ground truth labels are clamped to :math:`[p, 1-p]`.

        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

        normalize (bool, optional): If ``True``, output loss will be divided by the number of
            positive elements in ``target``.

    Shape:
        - Input: :math:`(N, *)` where :math:`*` means, any number of additional dimensions
        - Target: :math:`(N, *)`, same shape as the input
        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
          shape as input.

    Examples::

        >>> loss = FocalLoss(gamma=1.0, pos_weight=0.8)
        >>> pred = torch.rand(10, 10, requires_grad=True)
        >>> target = torch.rand(10, 10).round()
        >>> output = loss(pred, target)

    .. _Focal Loss For Dense Object Detection:
        https://arxiv.org/abs/1708.02002
    """

    # Probability-space functional implementation used by ``forward``.
    _loss = focal_loss
class FocalLossWithLogits(_FocalLoss):
    r"""Creates a criterion that measures the Focal Loss between the target and the output.
    Focal loss is described in the paper `Focal Loss For Dense Object Detection`_.

    Inputs are expected to be logits (i.e. not already scaled to the interval :math:`[0, 1]`
    through a sigmoid or softmax). This computation on logits is more numerically stable
    and efficient for reverse mode auto-differentiation and should be preferred for that
    use case.

    Args:
        gamma (float): The focusing parameter :math:`\gamma`. Must be non-negative.

        pos_weight (float, optional): The positive weight coefficient :math:`\alpha` to use on
            the positive examples. Non-positive examples are weighted by ``1 - pos_weight``.
            Must be non-negative.

        label_smoothing (float, optional): Float in [0, 1]. When 0, no smoothing occurs. When positive,
            the binary ground truth labels are clamped to :math:`[p, 1-p]`.

        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

        normalize (bool, optional): If ``True``, output loss will be divided by the number of
            positive elements in ``target``.

    Shape:
        - Input: :math:`(N, *)` where :math:`*` means, any number of additional dimensions
        - Target: :math:`(N, *)`, same shape as the input
        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
          shape as input.

    Examples::

        >>> loss = FocalLossWithLogits(gamma=1.0, pos_weight=0.8)
        >>> pred = torch.randn(10, 10, requires_grad=True)  # unbounded logits
        >>> target = torch.rand(10, 10).round()
        >>> output = loss(pred, target)

    .. _Focal Loss For Dense Object Detection:
        https://arxiv.org/abs/1708.02002
    """

    # Logit-space functional implementation used by ``forward``.
    _loss = focal_loss_with_logits
# Public API of this module: the names exported by a star-import.
# The ``_FocalLoss`` base class is intentionally not listed.
__all__ = [ "focal_loss_with_logits", "focal_loss", "FocalLoss", "FocalLossWithLogits", ]

© Copyright 2020, Scott Chase Waggener. Revision cac3fb98.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.1.0rc1
Versions
latest
docs
v0.1.0rc1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources