Shortcuts

Source code for combustion.nn.modules.preprocessing

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor


class Standardize(nn.Module):
    r"""Standardizes an input tensor to zero mean unit variance along one or more dimensions.
    Mean and variance will be computed over the selected dimensions, and the resultant tensor
    will be computed as

    .. math::
        x_o = \frac{x_i - \mu}{\sqrt{\max(\sigma^2, \epsilon)}}

    Args:
        dims (int or tuple of ints):
            The dimension(s) to standardize over. Negative values index from the last
            dimension, following PyTorch's usual dim-wrapping rules.

        epsilon (float, optional):
            Lower bound on variance. Its absolute value is used.

        unbiased (bool, optional):
            Whether or not to use unbiased estimation in variance calculation.
            See :func:`torch.var_mean` for more details.

    Shape:
        - Inputs: Tensor of shape :math:`(*)` where :math:`*` indicates an arbitrary number of dimensions.
        - Output: Same shape as input.
    """

    def __init__(self, dims: Union[int, Tuple[int, ...]], epsilon: float = 1e-9, unbiased: bool = True):
        super(Standardize, self).__init__()
        if isinstance(dims, int):
            dims = (dims,)
        # Stored as a set: duplicate entries collapse and order is irrelevant to a reduction.
        self.dims = {int(x) for x in dims}
        # Sign is discarded so a negative epsilon cannot disable the variance floor.
        self.epsilon = abs(float(epsilon))
        self.unbiased = bool(unbiased)

    def __repr__(self):
        s = f"Standardize(dims={tuple(self.dims)}"
        if self.epsilon != 1e-9:
            s += f", epsilon={self.epsilon}"
        if not self.unbiased:
            s += f", unbiased={self.unbiased}"
        s += ")"
        return s

    def forward(self, inputs: Tensor) -> Tensor:
        r"""
        Args:
            inputs (Tensor): The tensor to be standardized

        Returns:
            Tensor of the same shape as ``inputs`` with zero mean and unit variance over ``dims``
            (variance floored at ``epsilon`` before normalizing).

        Raises:
            ValueError: If any entry of ``dims`` lies outside ``[-inputs.ndim, inputs.ndim)``.
        """
        ndim = inputs.ndim
        for dim in self.dims:
            # Same valid range PyTorch uses for dim wrapping; -ndim is a legal index.
            if not -ndim <= dim < ndim:
                raise ValueError(f"Invalid dim {dim} for input of shape {inputs.shape}")

        # Map negative dims to their non-negative equivalents so aliases such as (1, -1)
        # reduce the shared axis once instead of making torch.var_mean reject the duplicate.
        dims = tuple(sorted({dim % ndim for dim in self.dims}))

        var, mean = torch.var_mean(inputs, dim=dims, keepdim=True, unbiased=self.unbiased)
        # Floor the variance, then take its square root: dividing by the standard deviation
        # (not the variance) is what yields unit variance. Out-of-place ops keep autograd safe.
        std = var.clamp(min=self.epsilon).sqrt()
        return (inputs - mean) / std

© Copyright 2020, Scott Chase Waggener. Revision cac3fb98.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.1.0rc1
Versions
latest
docs
v0.1.0rc1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources