# Source code for todd.tasks.object_detection.bboxes

# pylint: disable=invalid-name

# Public API of this module: the ``BBox`` tuple alias, the abstract
# ``BBoxes`` base, the coordinate-layout mixins (first pair / last pair),
# the three concrete box formats, and their flattened counterparts.
__all__ = [
    'BBox',
    'BBoxes',
    # Layout mixins: ``XY__``/``CXCY__`` define the first two coordinates,
    # ``__XY``/``__WH`` define the last two.
    'BBoxesXY__',
    'BBoxesCXCY__',
    'BBoxes__XY',
    'BBoxes__WH',
    # Concrete formats.
    'BBoxesXYXY',
    'BBoxesXYWH',
    'BBoxesCXCYWH',
    # Flattened variants with all-pairs operations.
    'FlattenBBoxesMixin',
    'FlattenBBoxesXYXY',
    'FlattenBBoxesXYWH',
    'FlattenBBoxesCXCYWH',
]

from abc import ABC, abstractmethod
from typing import TypeVar
from typing_extensions import Self

import einops.layers.torch
import torch

from ..utils import FlattenMixin, NormalizeMixin, TensorWrapper
from .registries import ODBBoxesRegistry

# One box as a plain tuple of four numbers; what each slot means depends on
# the box format class that produced it (see ``BBoxes.to_object``).
BBox = tuple[float, float, float, float]
# Bound type variable for ``BBoxes.to`` so that converting to a format class
# is typed as returning an instance of exactly that class.
T = TypeVar('T', bound='BBoxes')


class BBoxes(NormalizeMixin[BBox], TensorWrapper[BBox], ABC):
    """Abstract wrapper around a tensor of axis-aligned bounding boxes.

    The wrapped tensor's last dimension must have size 4. How those four
    values are laid out (two corners, corner + size, or center + size) is
    decided by concrete subclasses, which implement the abstract coordinate
    accessors declared here. Format-independent operations (translation,
    rounding, expansion, clamping, IoU, rasterization) are written once in
    terms of those accessors, usually by converting to a convenient format
    and converting back afterwards.
    """

    OBJECT_DIMENSIONS = 1

    @classmethod
    def to_object(cls, tensor: torch.Tensor) -> BBox:
        """Convert the tensor of a single box into a plain Python tuple.

        Args:
            tensor: the coordinates of one box (presumably a 1-D tensor of
                four values -- TODO confirm against ``TensorWrapper``).

        Returns:
            The coordinates in the order they are stored in ``tensor``.
        """
        return tuple(tensor.tolist())

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Every box must carry exactly four coordinates.
        assert self._tensor.shape[-1] == 4

    def _scale(self, ratio_xy: tuple[float, ...], /) -> torch.Tensor:
        # ``ratio_xy * 2`` repeats the tuple, e.g. ``(rx, ry)`` becomes
        # ``(rx, ry, rx, ry)``: in every layout slots 0/2 are x-like and
        # slots 1/3 are y-like, so each is scaled by the matching ratio.
        return self._tensor * self._tensor.new_tensor(ratio_xy * 2)

    @property
    @abstractmethod
    def left(self) -> torch.Tensor:
        """x coordinate of each box's left edge."""

    @property
    @abstractmethod
    def right(self) -> torch.Tensor:
        """x coordinate of each box's right edge."""

    @property
    @abstractmethod
    def top(self) -> torch.Tensor:
        """y coordinate of each box's top edge."""

    @property
    @abstractmethod
    def bottom(self) -> torch.Tensor:
        """y coordinate of each box's bottom edge."""

    @property
    @abstractmethod
    def width(self) -> torch.Tensor:
        """Width of each box."""

    @property
    @abstractmethod
    def height(self) -> torch.Tensor:
        """Height of each box."""

    @property
    @abstractmethod
    def center_x(self) -> torch.Tensor:
        """x coordinate of each box's center."""

    @property
    @abstractmethod
    def center_y(self) -> torch.Tensor:
        """y coordinate of each box's center."""

    @property
    @abstractmethod
    def lt(self) -> torch.Tensor:
        """Left-top corner ``(left, top)``; last dimension of size 2."""

    @property
    @abstractmethod
    def rb(self) -> torch.Tensor:
        """Right-bottom corner ``(right, bottom)``; last dimension of size 2."""

    @property
    @abstractmethod
    def wh(self) -> torch.Tensor:
        """Size ``(width, height)``; last dimension of size 2."""

    @property
    @abstractmethod
    def center(self) -> torch.Tensor:
        """Center ``(center_x, center_y)``; last dimension of size 2."""

    @property
    def area(self) -> torch.Tensor:
        """Area of each box, computed as ``width * height``."""
        return self.width * self.height

    @classmethod
    @abstractmethod
    def _from1(cls, bboxes: 'BBoxes') -> torch.Tensor:
        """Convert to the first two coordinates of the bboxes.

        Args:
            bboxes: the bboxes to convert.

        Returns:
            A tensor representing the first two coordinates of the bboxes.
        """

    @classmethod
    @abstractmethod
    def _from2(cls, bboxes: 'BBoxes') -> torch.Tensor:
        """Convert to the last two coordinates of the bboxes.

        Args:
            bboxes: the bboxes to convert.

        Returns:
            A tensor representing the last two coordinates of the bboxes.
        """

    @classmethod
    def from_(cls, bboxes: 'BBoxes') -> Self:
        """Build boxes of this format from boxes in any other format.

        Args:
            bboxes: the source boxes.

        Returns:
            New ``cls`` boxes describing the same geometry. All state from
            ``bboxes.__getstate__()`` except its first positional entry is
            forwarded to the constructor unchanged.
        """
        from1 = cls._from1(bboxes)
        from2 = cls._from2(bboxes)
        from_ = torch.cat([from1, from2], -1)
        # The first positional state entry is discarded and replaced by the
        # converted coordinates; everything else is passed through.
        (_, *args), kwargs = bboxes.__getstate__()
        return cls(from_, *args, **kwargs)

    def to(self, cls: type[T]) -> T:
        """Convert these boxes to the format ``cls`` via ``cls.from_``."""
        return cls.from_(self)

    def translate(
        self,
        offset_xy: tuple[float, float] | torch.Tensor,
    ) -> Self:
        """Shift every box by ``offset_xy``.

        Args:
            offset_xy: ``(dx, dy)``. A tuple is converted with
                ``new_tensor``; a tensor is used as given and is repeated to
                ``(dx, dy, dx, dy)`` along its last dimension.

        Returns:
            Translated boxes in the same format as ``self``.
        """
        if isinstance(offset_xy, tuple):
            offset_xy = self._tensor.new_tensor(offset_xy)
        # Translation is simplest in corner form: add the offset to both
        # corners, then convert back to the caller's format.
        bboxes = self.to(BBoxesXYXY)
        tensor = bboxes._tensor + torch.cat([offset_xy, offset_xy], -1)
        bboxes = bboxes.copy(tensor)
        return bboxes.to(self.__class__)

    def round(self) -> Self:
        """Snap boxes outward to integer coordinates.

        The left-top corner is floored and the right-bottom corner is ceiled,
        so each rounded box covers its original extent. Normalized boxes are
        denormalized before rounding and normalized again afterwards.
        """
        bboxes = self.to(BBoxesXYXY)
        if normalized := bboxes._normalized:
            bboxes = bboxes.denormalize()
        lt = bboxes.lt.floor()
        rb = bboxes.rb.ceil()
        tensor = torch.cat([lt, rb], -1)
        bboxes = bboxes.copy(tensor)
        if normalized:
            bboxes = bboxes.normalize()
        return bboxes.to(self.__class__)

    def expand(self, ratio_wh: tuple[float, float] | torch.Tensor) -> Self:
        """Scale each box's width and height about its center.

        Args:
            ratio_wh: ``(rw, rh)`` multipliers for width and height.

        Returns:
            Expanded boxes in the same format as ``self``.
        """
        if isinstance(ratio_wh, tuple):
            ratio_wh = self._tensor.new_tensor(ratio_wh)
        # Center/size form keeps the center fixed while the size is scaled.
        bboxes = self.to(BBoxesCXCYWH)
        tensor = torch.cat([bboxes.center, bboxes.wh * ratio_wh], -1)
        bboxes = bboxes.copy(tensor)
        return bboxes.to(self.__class__)

    def clamp(self) -> Self:
        """Clip box corners to the valid coordinate range.

        Normalized boxes are clipped to ``[0, 1]``. Otherwise coordinates are
        clipped to ``[0, divisor]`` per axis, where ``divisor`` is the
        ``(w, h)`` pair provided by the normalization mixin.
        """
        bboxes = self.to(BBoxesXYXY)
        if bboxes._normalized:
            tensor = bboxes._tensor.clamp(0, 1)
        else:
            tensor = bboxes._tensor.clamp_min(0)
            # ``divisor * 2`` repeats ``(w, h)`` as the upper bound for both
            # corners: ``(w, h, w, h)``.
            tensor = tensor.clamp_max(
                bboxes._tensor.new_tensor(bboxes.divisor * 2),
            )
        bboxes = bboxes.copy(tensor)
        return bboxes.to(self.__class__)

    def indices(
        self,
        *,
        min_area: float | None = None,
        min_wh: tuple[float, float] | None = None,
    ) -> torch.Tensor:
        """Boolean mask of the boxes that pass the given size filters.

        Args:
            min_area: if given, keep only boxes whose area is at least this.
            min_wh: if given, keep only boxes whose width and height are both
                at least the corresponding values.

        Returns:
            A boolean tensor of shape ``self.shape``; ``True`` marks kept
            boxes.
        """
        indices = self._tensor.new_ones(self.shape, dtype=torch.bool)
        if min_area is not None:
            indices &= self.area >= min_area
        if min_wh is not None:
            # Create the threshold on the boxes' device (a CPU tensor would
            # fail to compare with CUDA boxes); ``torch.tensor`` still infers
            # the dtype from ``min_wh`` so float thresholds are not truncated.
            threshold = torch.tensor(min_wh, device=self._tensor.device)
            indices &= (self.wh >= threshold).all(-1)
        return indices

    def pairwise_intersections(self, other: 'BBoxes') -> torch.Tensor:
        """Intersection areas of corresponding boxes in ``self`` and ``other``.

        Boxes are matched element-wise (with broadcasting), unlike the
        all-pairs ``FlattenBBoxesMixin.intersections``.
        """
        lt = torch.maximum(self.lt, other.lt)
        rb = torch.minimum(self.rb, other.rb)
        # Disjoint boxes produce negative extents; clamp them to zero area.
        wh = (rb - lt).clamp_min(0)
        # Index the last axis so any number of leading batch dimensions is
        # supported, not only 2-D ``(n, 4)`` inputs.
        return wh[..., 0] * wh[..., 1]

    def _pairwise_unions(
        self,
        other: 'BBoxes',
        intersections: torch.Tensor,
    ) -> torch.Tensor:
        # Inclusion-exclusion: |A or B| = |A| + |B| - |A and B|.
        return self.area + other.area - intersections

    def pairwise_unions(self, other: 'BBoxes') -> torch.Tensor:
        """Union areas of corresponding boxes in ``self`` and ``other``."""
        intersections = self.pairwise_intersections(other)
        return self._pairwise_unions(other, intersections)

    def pairwise_ious(
        self,
        other: 'BBoxes',
        eps: float = 1e-6,
    ) -> torch.Tensor:
        """Intersection over union of corresponding boxes.

        Args:
            other: boxes matched element-wise against ``self``.
            eps: lower bound on the union, to avoid division by zero.

        Returns:
            The IoU of each pair of corresponding boxes.
        """
        intersections = self.pairwise_intersections(other)
        unions = self._pairwise_unions(other, intersections)
        unions = unions.clamp_min(eps)
        return intersections / unions

    def to_mask(self) -> torch.Tensor:
        """Rasterize every box into a boolean ``h x w`` grid.

        ``(w, h)`` is taken from ``self.divisor``. Cell ``(y, x)`` is set when
        ``left <= x <= right`` and ``top <= y <= bottom`` (both bounds are
        inclusive). Box coordinates are compared directly with integer grid
        indices, so this presumably expects denormalized boxes -- TODO
        confirm.

        Returns:
            A boolean tensor with the boxes' batch shape followed by
            ``(h, w)``.
        """
        w, h = self.divisor
        x = torch.arange(w, device=self._tensor.device)
        y = torch.arange(h, device=self._tensor.device)
        # Append a trailing axis so per-box bounds broadcast against the 1-D
        # coordinate ranges.
        rearrange = einops.layers.torch.Rearrange('... -> ... 1')
        x_mask = (rearrange(self.left) <= x) & (x <= rearrange(self.right))
        y_mask = (rearrange(self.top) <= y) & (y <= rearrange(self.bottom))
        x_mask = einops.rearrange(x_mask, '... d -> ... 1 d')  # [..., 1, w]
        y_mask = rearrange(y_mask)  # [..., h, 1]
        mask = x_mask & y_mask
        return mask
class BBoxesXY__(BBoxes, ABC):
    """Layout mixin: the first two coordinates are the left-top corner.

    Accessors index the last axis with ``...`` so boxes with any number of
    leading batch dimensions are read correctly (``[:, k]`` would index the
    second axis of a 3-D tensor instead of the coordinate axis).
    """

    @property
    def left(self) -> torch.Tensor:
        return self._tensor[..., 0]

    @property
    def top(self) -> torch.Tensor:
        return self._tensor[..., 1]

    @property
    def lt(self) -> torch.Tensor:
        return self._tensor[..., :2]

    @classmethod
    def _from1(cls, bboxes: BBoxes) -> torch.Tensor:
        # The first stored pair of this layout is the left-top corner.
        return bboxes.lt
class BBoxesCXCY__(BBoxes, ABC):
    """Layout mixin: the first two coordinates are the box center.

    Accessors index the last axis with ``...`` so boxes with any number of
    leading batch dimensions are read correctly.
    """

    @property
    def center_x(self) -> torch.Tensor:
        return self._tensor[..., 0]

    @property
    def center_y(self) -> torch.Tensor:
        return self._tensor[..., 1]

    @property
    def center(self) -> torch.Tensor:
        return self._tensor[..., :2]

    @classmethod
    def _from1(cls, bboxes: BBoxes) -> torch.Tensor:
        # The first stored pair of this layout is the center.
        return bboxes.center
class BBoxes__XY(BBoxes, ABC):  # noqa: N801
    """Layout mixin: the last two coordinates are the right-bottom corner.

    Accessors index the last axis with ``...`` so boxes with any number of
    leading batch dimensions are read correctly.
    """

    @property
    def right(self) -> torch.Tensor:
        return self._tensor[..., 2]

    @property
    def bottom(self) -> torch.Tensor:
        return self._tensor[..., 3]

    @property
    def rb(self) -> torch.Tensor:
        return self._tensor[..., 2:]

    @classmethod
    def _from2(cls, bboxes: BBoxes) -> torch.Tensor:
        # The second stored pair of this layout is the right-bottom corner.
        return bboxes.rb
class BBoxes__WH(BBoxes, ABC):  # noqa: N801
    """Layout mixin: the last two coordinates are the box size.

    Accessors index the last axis with ``...`` so boxes with any number of
    leading batch dimensions are read correctly.
    """

    @property
    def width(self) -> torch.Tensor:
        return self._tensor[..., 2]

    @property
    def height(self) -> torch.Tensor:
        return self._tensor[..., 3]

    @property
    def wh(self) -> torch.Tensor:
        return self._tensor[..., 2:]

    @classmethod
    def _from2(cls, bboxes: BBoxes) -> torch.Tensor:
        # The second stored pair of this layout is the size.
        return bboxes.wh
@ODBBoxesRegistry.register_()
class BBoxesXYXY(BBoxesXY__, BBoxes__XY):
    """Boxes stored as ``(left, top, right, bottom)``.

    Both corners come straight from the layout mixins; size and center are
    derived from the corners.
    """

    @property
    def wh(self) -> torch.Tensor:
        """Size as the difference of the two corners."""
        return self.rb - self.lt

    @property
    def width(self) -> torch.Tensor:
        return self.right - self.left

    @property
    def height(self) -> torch.Tensor:
        return self.bottom - self.top

    @property
    def center(self) -> torch.Tensor:
        """Midpoint of the two corners."""
        return (self.lt + self.rb) / 2

    @property
    def center_x(self) -> torch.Tensor:
        return (self.left + self.right) / 2

    @property
    def center_y(self) -> torch.Tensor:
        return (self.top + self.bottom) / 2

    def flatten(self) -> 'FlattenBBoxesXYXY':
        """Rewrap ``self._flatten()`` as :class:`FlattenBBoxesXYXY`.

        The remaining state from ``__getstate__`` is carried over. What
        ``_flatten`` does comes from a parent class not shown here --
        presumably it collapses leading batch dimensions.
        """
        flattened = self.copy(self._flatten())
        state_args, state_kwargs = flattened.__getstate__()
        return FlattenBBoxesXYXY(*state_args, **state_kwargs)
@ODBBoxesRegistry.register_()
class BBoxesXYWH(BBoxesXY__, BBoxes__WH):
    """Boxes stored as ``(left, top, width, height)``.

    The left-top corner and size come from the layout mixins; the
    right-bottom corner and center are derived from them.
    """

    @property
    def rb(self) -> torch.Tensor:
        """Right-bottom corner: left-top corner plus size."""
        return self.lt + self.wh

    @property
    def right(self) -> torch.Tensor:
        return self.left + self.width

    @property
    def bottom(self) -> torch.Tensor:
        return self.top + self.height

    @property
    def center(self) -> torch.Tensor:
        """Left-top corner plus half the size."""
        return self.lt + self.wh / 2

    @property
    def center_x(self) -> torch.Tensor:
        return self.left + self.width / 2

    @property
    def center_y(self) -> torch.Tensor:
        return self.top + self.height / 2

    def flatten(self) -> 'FlattenBBoxesXYWH':
        """Rewrap ``self._flatten()`` as :class:`FlattenBBoxesXYWH`.

        The remaining state from ``__getstate__`` is carried over. What
        ``_flatten`` does comes from a parent class not shown here --
        presumably it collapses leading batch dimensions.
        """
        flattened = self.copy(self._flatten())
        state_args, state_kwargs = flattened.__getstate__()
        return FlattenBBoxesXYWH(*state_args, **state_kwargs)
@ODBBoxesRegistry.register_()
class BBoxesCXCYWH(BBoxesCXCY__, BBoxes__WH):
    """Boxes stored as ``(center_x, center_y, width, height)``.

    The center and size come from the layout mixins; the edges and corners
    are found by stepping half the size away from the center.
    """

    @property
    def lt(self) -> torch.Tensor:
        """Center minus half the size."""
        return self.center - self.wh / 2

    @property
    def rb(self) -> torch.Tensor:
        """Center plus half the size."""
        return self.center + self.wh / 2

    @property
    def left(self) -> torch.Tensor:
        return self.center_x - self.width / 2

    @property
    def right(self) -> torch.Tensor:
        return self.center_x + self.width / 2

    @property
    def top(self) -> torch.Tensor:
        return self.center_y - self.height / 2

    @property
    def bottom(self) -> torch.Tensor:
        return self.center_y + self.height / 2

    def flatten(self) -> 'FlattenBBoxesCXCYWH':
        """Rewrap ``self._flatten()`` as :class:`FlattenBBoxesCXCYWH`.

        The remaining state from ``__getstate__`` is carried over. What
        ``_flatten`` does comes from a parent class not shown here --
        presumably it collapses leading batch dimensions.
        """
        flattened = self.copy(self._flatten())
        state_args, state_kwargs = flattened.__getstate__()
        return FlattenBBoxesCXCYWH(*state_args, **state_kwargs)
class FlattenBBoxesMixin(FlattenMixin[BBox], BBoxes, ABC):
    """All-pairs box arithmetic for flattened boxes.

    Each method compares every box of ``self`` (``n`` boxes) with every box
    of ``other`` (``n'`` boxes) and returns an ``n x n'`` matrix. The ``&``
    and ``|`` operators are shorthands for :meth:`intersections` and
    :meth:`unions`.
    """

    def intersections(self, other: 'FlattenBBoxesMixin') -> torch.Tensor:
        r"""Intersections.

        Args:
            other: :math:`n' \times 4`.

        Returns:
            :math:`n \times n'`.
        """
        # Give ``self`` a new axis 1 and ``other`` a new axis 0 so that the
        # element-wise max/min below covers every (i, j) pair: [n, n', 2].
        self_lt = einops.rearrange(self.lt, 'n1 lt -> n1 1 lt')
        other_lt = einops.rearrange(other.lt, 'n2 lt -> 1 n2 lt')
        inner_lt = torch.maximum(self_lt, other_lt)

        self_rb = einops.rearrange(self.rb, 'n1 rb -> n1 1 rb')
        other_rb = einops.rearrange(other.rb, 'n2 rb -> 1 n2 rb')
        inner_rb = torch.minimum(self_rb, other_rb)

        # Non-overlapping pairs have negative extent; treat them as empty.
        overlap = (inner_rb - inner_lt).clamp_min(0)
        return overlap[..., 0] * overlap[..., 1]

    def __and__(self, other: 'FlattenBBoxesMixin') -> torch.Tensor:
        return self.intersections(other)

    def _unions(
        self,
        other: 'FlattenBBoxesMixin',
        intersections: torch.Tensor,
    ) -> torch.Tensor:
        r"""Unions.

        Args:
            other: :math:`n' \times 4`.
            intersections: :math:`n \times n'`

        Returns:
            :math:`n \times n'`.
        """
        # Inclusion-exclusion, with the two area vectors broadcast into a
        # matrix: |A or B| = |A| + |B| - |A and B|.
        row_areas = self.area.unsqueeze(1)  # [n, 1]
        col_areas = other.area.unsqueeze(0)  # [1, n']
        return row_areas + col_areas - intersections

    def unions(self, other: 'FlattenBBoxesMixin') -> torch.Tensor:
        r"""Unions.

        Args:
            other: :math:`n' \times 4`.

        Returns:
            :math:`n \times n'`.
        """
        return self._unions(other, self.intersections(other))

    def __or__(self, other: 'FlattenBBoxesMixin') -> torch.Tensor:
        return self.unions(other)

    def ious(
        self,
        other: 'FlattenBBoxesMixin',
        eps: float = 1e-6,
    ) -> torch.Tensor:
        r"""Intersections over unions.

        Args:
            other: :math:`n' \times 4`.
            eps: avoid division by zero.

        Returns:
            :math:`n \times n'`.
        """
        overlaps = self.intersections(other)
        # Intersections are computed once and reused for the union.
        denominators = self._unions(other, overlaps).clamp_min(eps)
        return overlaps / denominators
@ODBBoxesRegistry.register_()
class FlattenBBoxesXYXY(FlattenBBoxesMixin, BBoxesXYXY):
    """Flattened ``(left, top, right, bottom)`` boxes with all-pairs ops."""
@ODBBoxesRegistry.register_()
class FlattenBBoxesXYWH(FlattenBBoxesMixin, BBoxesXYWH):
    """Flattened ``(left, top, width, height)`` boxes with all-pairs ops."""
@ODBBoxesRegistry.register_()
class FlattenBBoxesCXCYWH(FlattenBBoxesMixin, BBoxesCXCYWH):
    """Flattened ``(center_x, center_y, width, height)`` boxes with all-pairs
    ops."""