
Source code for

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Dict, Optional, Tuple, Union

import torch
from numpy import ndarray
from torch import Tensor

    import cv2
except ImportError:

    class cv2:
        def __getattr__(self, attr):
            raise ImportError(
                "Bounding box visualization requires cv2. "
                "Please install combustion with 'vision' extras using "
                "pip install combustion [vision]"

        def __setattr__(self, attr):
            raise ImportError(
                "Bounding box visualization requires cv2. "
                "Please install combustion with 'vision' extras using "
                "pip install combustion [vision]"

# White as a 3-tuple of channel values; identical to the default value of the
# ``text_color`` argument of ``visualize_bbox``.
# NOTE(review): not referenced by the code visible in this file - confirm
# whether other modules import it before removing.
TEXT_COLOR = (255, 255, 255)

def visualize_bbox(
    img: Union[Tensor, ndarray],
    bbox: Optional[Union[Tensor, ndarray]] = None,
    label: Optional[Union[Tensor, ndarray]] = None,
    scores: Optional[Union[Tensor, ndarray]] = None,
    class_names: Optional[Dict[int, str]] = None,
    box_color: Tuple[int, int, int] = (255, 0, 0),
    text_color: Tuple[int, int, int] = (255, 255, 255),
    label_alpha: float = 0.4,
    thickness: int = 2,
) -> Union[Tensor, ndarray]:
    r"""Adds bounding box visualization to an input array

    Drawing is performed on a contiguous copy of ``img``; the array or tensor
    passed by the caller is not modified.

    Args:
        img (Tensor or numpy.ndarray):
            The image to draw anchor boxes on. Must have 2 dimensions, or
            3 dimensions in channels-first order. Inputs with fewer than
            three channels are converted with ``cv2.COLOR_GRAY2RGB`` before
            drawing.

        bbox (Tensor or numpy.ndarray, optional):
            The anchor boxes to draw, with shape ``(N, 4)``. Each row is read
            as ``x_min, y_min, x_max, y_max`` and truncated to ``int``.

        label (Tensor or numpy.ndarray, optional):
            Class labels associated with each anchor box, with shape
            ``(N, 1)``. Only used when ``bbox`` is given.

        scores (Tensor or numpy.ndarray, optional):
            Class scores associated with each anchor box, with shape
            ``(N, 1)``. Only drawn when ``label`` is also given.

        class_names (dict, optional):
            Dictionary mapping integer class labels to string names.
            If ``label`` is supplied but ``class_names`` is not, integer
            class labels will be used.

        box_color (tuple of ints, optional):
            A 3-tuple giving the RGB color value to use for anchor boxes.

        text_color (tuple of ints, optional):
            A 3-tuple giving the RGB color value to use for labels.

        label_alpha (float, optional):
            Alpha to apply to the colored background for class labels.
            Accepted for interface compatibility; this implementation does
            not read it.

        thickness (int, optional):
            Specifies the thickness of anchor boxes.

    Returns:
        :class:`torch.Tensor` or :class:`numpy.ndarray` (depending on what
        was given for `img`) with the output image in channels-first order.

    Raises:
        ValueError: If ``img`` is ``None`` or an input has an unexpected
            number of dimensions or shape.
        TypeError: If an input is neither a tensor nor a numpy array.
        KeyError: If ``class_names`` is given but lacks a drawn label.
    """
    # isinstance (rather than an exact type comparison) so that Tensor
    # subclasses also get a Tensor back.
    return_tensor = isinstance(img, Tensor)

    img_t = _check_input(img, "img", (2, 3))
    bbox_t = _check_input(bbox, "bbox", 2, (None, 4))
    label_t = _check_input(label, "label", 2, (None, 1))
    scores_t = _check_input(scores, "scores", 2, (None, 1))

    # permute to channels last
    if img_t.ndim == 3:
        img_t = img_t.permute(1, 2, 0)
    else:
        img_t = img_t.unsqueeze(-1)

    # The permuted view is non-contiguous and still aliases the caller's
    # buffer. OpenCV's in-place drawing functions need a contiguous array, so
    # materialize a contiguous private copy before handing it to cv2.
    canvas = img_t.detach().clone(memory_format=torch.contiguous_format).cpu().numpy()
    boxes, labels, box_scores = [
        t.detach().cpu().numpy() if t is not None else None for t in (bbox_t, label_t, scores_t)
    ]

    # convert grayscale input to color for bounding boxes
    if canvas.shape[-1] < 3:
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2RGB)

    if boxes is not None:
        for i, coords in enumerate(boxes):
            x_min, y_min, x_max, y_max = [int(c) for c in coords]

            # bounding box
            cv2.rectangle(
                canvas,
                (x_min, y_min),
                (x_max, y_max),
                box_color,
                thickness,
            )

            if labels is None:
                continue

            cls = labels[i].item()
            if class_names is not None:
                class_name = class_names[cls]
            else:
                class_name = f"Class {cls}"

            if box_scores is not None:
                class_name += f" - {box_scores[i].item():0.3f}"

            # tag bounding box with class name / integer id: a filled
            # rectangle sitting on the top edge of the box, text on top of it
            ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
            cv2.rectangle(canvas, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), box_color, -1)
            cv2.putText(
                canvas,
                class_name,
                (x_min, y_min - int(0.3 * text_height)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.35,
                text_color,
                lineType=cv2.LINE_AA,
            )

    # restore original data type with channels first ordering
    channel, height, width = -1, 0, 1
    if return_tensor:
        return torch.from_numpy(canvas).permute(channel, height, width)
    return canvas.transpose(channel, height, width)
def _check_input(x, name, ndim=None, shape=None): if name != "img" and x is None: return None elif x is None: raise ValueError("img cannot be None") if not isinstance(x, (Tensor, ndarray)): raise TypeError(f"Expected Tensor or np.ndarray for {name}, found {type(x)}") if isinstance(x, ndarray): x = torch.from_numpy(x) if ndim is not None: if isinstance(ndim, int): if x.ndim != ndim: raise ValueError(f"Expected {name}.ndim = {ndim}, found {x.ndim}") elif x.ndim < ndim[0] or x.ndim > ndim[1]: raise ValueError(f"Expected {ndim[0]} <= {name}.ndim = {ndim[1]}, found {x.ndim}") if shape is not None: for i, dim in enumerate(shape): if dim is not None and x.shape[i] != dim: raise ValueError(f"Expected {name}.shape = {shape}, found {x.shape}") return x

© Copyright 2020, Scott Chase Waggener. Revision cac3fb98.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.1.0rc1
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources