Shortcuts

Source code for combustion.vision.bbox

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Dict, Optional, Tuple, Union

import torch
from numpy import ndarray
from torch import Tensor


try:
    import cv2
except ImportError:

    class _MissingCV2:
        r"""Stand-in for the ``cv2`` module when OpenCV is not installed.

        Reading or writing any ordinary attribute (e.g. ``cv2.rectangle``) raises an
        :class:`ImportError` explaining how to install the optional dependency, so the
        failure surfaces where OpenCV is actually used rather than at import time.
        """

        _MESSAGE = (
            "Bounding box visualization requires cv2. "
            "Please install combustion with 'vision' extras using "
            "pip install combustion [vision]"
        )

        def __getattr__(self, attr):
            # Only reached when normal lookup fails. Dunder probes (hasattr checks made
            # by copy/inspect/doctest tooling) keep the standard AttributeError contract.
            if attr.startswith("__") and attr.endswith("__"):
                raise AttributeError(attr)
            raise ImportError(self._MESSAGE)

        def __setattr__(self, attr, value):
            raise ImportError(self._MESSAGE)

    # Must be an *instance*: __getattr__ defined on a class intercepts lookups on its
    # instances only, never lookups on the class object itself.
    cv2 = _MissingCV2()


# Default RGB colors for annotations.
BOX_COLOR = (255, 0, 0)  # box outline / label background (red in RGB order)
CORRECT_BOX_COLOR = BOX_COLOR  # alias bound to the very same tuple object
TEXT_COLOR = (255, 255, 255)  # label text (white)


def visualize_bbox(
    img: Union[Tensor, ndarray],
    bbox: Optional[Union[Tensor, ndarray]] = None,
    label: Optional[Union[Tensor, ndarray]] = None,
    scores: Optional[Union[Tensor, ndarray]] = None,
    class_names: Optional[Dict[int, str]] = None,
    box_color: Tuple[int, int, int] = (255, 0, 0),
    text_color: Tuple[int, int, int] = (255, 255, 255),
    label_alpha: float = 0.4,
    thickness: int = 2,
) -> Union[Tensor, ndarray]:
    r"""Adds bounding box visualization to an input array.

    The input image is never modified; boxes are drawn onto a copy.

    Args:
        img (Tensor or numpy.ndarray):
            The image to draw anchor boxes on, either channels-first ``(C, H, W)``
            or single-channel ``(H, W)``. Images with fewer than 3 channels are
            converted from grayscale to RGB before drawing (so presumably ``C == 1``
            for such inputs).

        bbox (Tensor or numpy.ndarray, optional):
            The anchor boxes to draw, shape ``(N, 4)``, each row ordered
            ``(x_min, y_min, x_max, y_max)``. Coordinates are truncated with ``int``.

        label (Tensor or numpy.ndarray, optional):
            Class labels associated with each anchor box, shape ``(N, 1)``.

        scores (Tensor or numpy.ndarray, optional):
            Class scores associated with each anchor box, shape ``(N, 1)``.
            Scores are only displayed when ``label`` is also supplied.

        class_names (dict, optional):
            Dictionary mapping integer class labels to string names. If ``label``
            is supplied but ``class_names`` is not, integer class labels will be used.

        box_color (tuple of ints, optional):
            A 3-tuple giving the RGB color value to use for anchor boxes and
            label backgrounds.

        text_color (tuple of ints, optional):
            A 3-tuple giving the RGB color value to use for labels.

        label_alpha (float, optional):
            Alpha to apply to the colored background for class labels.
            ``1.0`` draws an opaque background; ``0.0`` leaves the image unchanged.

        thickness (int, optional):
            Specifies the thickness of anchor boxes.

    Returns:
        :class:`torch.Tensor` (on the CPU) or :class:`numpy.ndarray` (depending on
        what was given for ``img``) with the output image in channels-first
        ``(C, H, W)`` layout.

    Raises:
        TypeError, ValueError: If an input is not a Tensor/ndarray or has an
            unexpected rank/shape (raised by ``_check_input``).
    """
    return_tensor = isinstance(img, Tensor)
    img = _check_input(img, "img", (2, 3))
    bbox = _check_input(bbox, "bbox", 2, (None, 4))
    label = _check_input(label, "label", 2, (None, 1))
    scores = _check_input(scores, "scores", 2, (None, 1))

    # permute to channels last, the (H, W, C) layout drawn on below
    img = img.permute(1, 2, 0) if img.ndim == 3 else img.unsqueeze(-1)

    # The permuted view is strided (OpenCV's drawing functions reject such layouts)
    # and .numpy() may share memory with the caller's input. ndarray.copy() returns
    # a private C-contiguous buffer that is safe to draw on.
    canvas = img.detach().cpu().numpy().copy()
    boxes = bbox.detach().cpu().numpy() if bbox is not None else None
    labels = label.detach().cpu().numpy() if label is not None else None

    # convert grayscale input to color for bounding boxes
    if canvas.shape[-1] < 3:
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2RGB)

    if boxes is not None:
        for i, coords in enumerate(boxes):
            x_min, y_min, x_max, y_max = (int(c) for c in coords)

            # bounding box outline
            cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), box_color, thickness)

            if labels is None:
                continue

            cls = labels[i].item()
            text = class_names[cls] if class_names is not None else f"Class {cls}"
            if scores is not None:
                text += f" - {scores[i].item():0.3f}"

            # tag bounding box with class name / integer id on a blended background
            (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
            _blend_filled_rect(
                canvas,
                (x_min, y_min - int(1.3 * text_height)),
                (x_min + text_width, y_min),
                box_color,
                label_alpha,
            )
            cv2.putText(
                canvas,
                text,
                (x_min, y_min - int(0.3 * text_height)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.35,
                text_color,
                lineType=cv2.LINE_AA,
            )

    # restore channels-first ordering: (H, W, C) -> (C, H, W)
    if return_tensor:
        return torch.from_numpy(canvas).permute(2, 0, 1)
    return canvas.transpose(2, 0, 1)


def _blend_filled_rect(
    img: ndarray,
    pt1: Tuple[int, int],
    pt2: Tuple[int, int],
    color: Tuple[int, int, int],
    alpha: float,
) -> None:
    r"""Alpha-blends a filled rectangle with corners ``pt1``/``pt2`` into ``img`` in place.

    ``img`` is an ``(H, W, C)`` array. The rectangle is clipped to the image bounds,
    and nothing is drawn if it lies entirely outside the image.
    """
    height, width = img.shape[:2]
    x0, y0 = max(pt1[0], 0), max(pt1[1], 0)
    # +1: the filled rectangle includes pt2, while slice ends are exclusive
    x1, y1 = min(pt2[0] + 1, width), min(pt2[1] + 1, height)
    if x0 >= x1 or y0 >= y1:
        return

    region = img[y0:y1, x0:x1]
    overlay = region.copy()
    cv2.rectangle(overlay, (0, 0), (x1 - x0 - 1, y1 - y0 - 1), color, -1)
    img[y0:y1, x0:x1] = cv2.addWeighted(overlay, alpha, region, 1.0 - alpha, 0)
def _check_input(x, name, ndim=None, shape=None): if name != "img" and x is None: return None elif x is None: raise ValueError("img cannot be None") if not isinstance(x, (Tensor, ndarray)): raise TypeError(f"Expected Tensor or np.ndarray for {name}, found {type(x)}") if isinstance(x, ndarray): x = torch.from_numpy(x) if ndim is not None: if isinstance(ndim, int): if x.ndim != ndim: raise ValueError(f"Expected {name}.ndim = {ndim}, found {x.ndim}") elif x.ndim < ndim[0] or x.ndim > ndim[1]: raise ValueError(f"Expected {ndim[0]} <= {name}.ndim = {ndim[1]}, found {x.ndim}") if shape is not None: for i, dim in enumerate(shape): if dim is not None and x.shape[i] != dim: raise ValueError(f"Expected {name}.shape = {shape}, found {x.shape}") return x

© Copyright 2020, Scott Chase Waggener. Revision cac3fb98.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.1.0rc1
Versions
latest
docs
v0.1.0rc1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources