Shortcuts

Source code for combustion.points.transforms

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from math import cos, radians, sin
from typing import Iterable, Tuple

import torch
import torch.nn as nn
from torch import Tensor


@torch.jit.script
def rotate(
    coords: Tensor, x: float = 0.0, y: float = 0.0, z: float = 0.0, degrees: bool = False, return_matrix: bool = False
) -> Tensor:
    """Rotate a collection of 3D points about the x, y and z axes.

    Args:
        coords: Points shaped ``(N, 3)`` or ``(B, N, 3)``.
        x: Rotation about the x-axis.
        y: Rotation about the y-axis.
        z: Rotation about the z-axis.
        degrees: When ``True`` the angles are in degrees, otherwise radians.
        return_matrix: When ``True`` return the ``(1, 3, 3)`` rotation matrix
            instead of the rotated points.

    Returns:
        The rotated points as a float32 tensor with the same shape as ``coords``,
        or the ``(1, 3, 3)`` rotation matrix when ``return_matrix`` is set.

    Raises:
        ValueError: If ``coords`` is not 2D or 3D, or its last dimension is not 3.
    """
    # validate inputs
    if coords.ndim > 3 or coords.ndim < 2:
        raise ValueError(f"Expected 2 <= coords.ndim <= 3 but coords.ndim == {coords.ndim}")
    if coords.shape[-1] != 3:
        raise ValueError(f"Expected coords.shape[-1] == 3 but found coords.shape[-1] == {coords.shape[-1]}")
    x, y, z = float(x), float(y), float(z)

    # degrees to radians if desired
    if degrees:
        x = radians(x)
        y = radians(y)
        z = radians(z)

    # Work on a (B, N, 3) float32 tensor. reshape (unlike view) also accepts
    # non-contiguous inputs such as transposed or strided slices.
    original_shape = coords.shape
    coords = coords.reshape([-1, coords.shape[-2], coords.shape[-1]]).float()

    # build per-axis rotation matrices
    rot_x = torch.tensor([[1.0, 0.0, 0.0], [0.0, cos(x), -sin(x)], [0.0, sin(x), cos(x)]]).type_as(coords)
    rot_y = torch.tensor([[cos(y), 0.0, sin(y)], [0.0, 1.0, 0.0], [-sin(y), 0.0, cos(y)]]).type_as(coords)
    rot_z = torch.tensor([[cos(z), -sin(z), 0.0], [sin(z), cos(z), 0.0], [0.0, 0.0, 1.0]]).type_as(coords)

    # Composite matrix R = Rz @ Rx @ Ry (same order as before); plain torch.mm
    # replaces the deprecated torch.chain_matmul.
    rotation_matrix = torch.mm(torch.mm(rot_z, rot_x), rot_y).unsqueeze(0)

    if return_matrix:
        return rotation_matrix

    # torch.matmul broadcasts the (1, 3, 3) matrix over the batch dimension;
    # torch.bmm does not broadcast and fails for any batch size B > 1.
    # NOTE(review): points are multiplied as row vectors (coords @ R), which
    # applies R^T to each point, i.e. a single-axis rotation by the negated
    # angle. Kept unchanged for backward compatibility - confirm intended.
    output = torch.matmul(coords, rotation_matrix)
    return output.reshape(original_shape)


class Rotate(nn.Module):
    r"""Rotates a collection of points by fixed angles given in radians or degrees.

    Args:
        x (float): Rotation about x-axis
        y (float): Rotation about y-axis
        z (float): Rotation about z-axis
        degrees (bool): By default rotations are in radians. When ``degrees=True``,
            rotations are treated as degrees.

    Shape
        * ``coords`` - :math:`(B, N, 3)` or :math:`(N, 3)`
        * Output - same as ``coords``
    """

    def __init__(self, x: float = 0.0, y: float = 0.0, z: float = 0.0, degrees: bool = False):
        super().__init__()
        # Store the angles as Python floats, matching ``rotate``'s float parameters.
        self.x, self.y, self.z = (float(angle) for angle in (x, y, z))
        self.degrees = degrees

    def extra_repr(self):
        pieces = [f"x={self.x}", f"y={self.y}", f"z={self.z}"]
        if self.degrees:
            pieces.append("degrees=True")
        return ", ".join(pieces)

    def forward(self, coords: Tensor) -> Tensor:
        return rotate(coords, self.x, self.y, self.z, self.degrees)


def random_rotate(
    coords: Tensor,
    x: Tuple[float, float] = (0.0, 0.0),
    y: Tuple[float, float] = (0.0, 0.0),
    z: Tuple[float, float] = (0.0, 0.0),
    degrees: bool = False,
    return_matrix: bool = False,
) -> Tensor:
    r"""Rotate points by angles drawn uniformly from per-axis ``(low, high)`` ranges.

    Args:
        coords: Points to rotate, :math:`(N, 3)` or :math:`(B, N, 3)`.
        x: ``(low, high)`` range for the rotation about the x-axis.
        y: ``(low, high)`` range for the rotation about the y-axis.
        z: ``(low, high)`` range for the rotation about the z-axis.
        degrees: When ``True`` the ranges are in degrees, otherwise radians.
        return_matrix: When ``True`` return the sampled rotation matrix instead
            of the rotated points.

    Returns:
        The result of :func:`rotate` for the sampled angles.

    Raises:
        TypeError: If a range is not iterable.
        ValueError: If a range does not have length 2 or has ``low > high``.
    """
    for var, s in zip((x, y, z), ("x", "y", "z")):
        if not isinstance(var, Iterable):
            raise TypeError(f"Expected {s} to be iterable, but found {type(var)}")
        if len(var) != 2:
            raise ValueError(f"Expected {s} to be of length 2, but found {len(var)}")
        if var[1] < var[0]:
            raise ValueError(f"Expected {s}_low <= {s}_high, but found {(var[0], var[1])}")

    # Build the bounds as float32 directly on coords' device. Casting them with
    # ``type_as(coords)`` first would truncate fractional bounds for integer
    # coords (and round them for half-precision coords).
    bounds = torch.tensor(
        [[x[0], x[1]], [y[0], y[1]], [z[0], z[1]]], dtype=torch.float, device=coords.device
    )
    lows = bounds.min(dim=-1).values
    highs = bounds.max(dim=-1).values

    # Uniform sample in [low, high) for each axis.
    angles = torch.rand_like(highs).mul_(highs - lows).add_(lows)

    # ``rotate`` is TorchScript-compiled with ``float`` parameters; pass plain
    # Python floats rather than relying on implicit 0-dim tensor conversion.
    angle_x, angle_y, angle_z = angles.tolist()
    return rotate(coords, angle_x, angle_y, angle_z, degrees, return_matrix)


class RandomRotate(nn.Module):
    r"""Rotates a collection of points by angles sampled between per-axis bounds.

    Args:
        x (tuple of floats): Minimum and maximum rotation about x-axis.
        y (tuple of floats): Minimum and maximum rotation about y-axis.
        z (tuple of floats): Minimum and maximum rotation about z-axis.
        degrees (bool): By default rotations are in radians. When ``degrees=True``,
            rotations are treated as degrees.

    Shape
        * ``coords`` - :math:`(B, N, 3)` or :math:`(N, 3)`
        * Output - same as ``coords``
    """

    def __init__(
        self,
        x: Tuple[float, float] = (0.0, 0.0),
        y: Tuple[float, float] = (0.0, 0.0),
        z: Tuple[float, float] = (0.0, 0.0),
        degrees: bool = False,
    ):
        super().__init__()
        # Validate each axis in x, y, z order; the first bad axis raises.
        for axis_name, axis_range in (("x", x), ("y", y), ("z", z)):
            self._check_range(axis_range, axis_name)
        self.x = x
        self.y = y
        self.z = z
        self.degrees = degrees

    @staticmethod
    def _check_range(var, s) -> None:
        # Require an iterable (low, high) pair with low <= high.
        if not isinstance(var, Iterable):
            raise TypeError(f"Expected {s} to be iterable, but found {type(var)}")
        if len(var) != 2:
            raise ValueError(f"Expected {s} to be of length 2, but found {len(var)}")
        if var[1] < var[0]:
            raise ValueError(f"Expected {s}_low <= {s}_high, but found {(var[0], var[1])}")

    def extra_repr(self):
        parts = [f"x={self.x}", f"y={self.y}", f"z={self.z}"]
        if self.degrees:
            parts.append("degrees=True")
        return ", ".join(parts)

    def forward(self, coords: Tensor) -> Tensor:
        return random_rotate(coords, self.x, self.y, self.z, self.degrees)

© Copyright 2020, Scott Chase Waggener. Revision 6d81d6b9.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.1.0rc2
Versions
latest
docs
0.1.0rc2
v0.1.0rc1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources