Source code for todd.datasets.objects365

# pylint: disable=duplicate-code

__all__ = [
    'Objects365Dataset',
]

import pathlib
from collections import UserList
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Iterable,
    Literal,
    Mapping,
    TypedDict,
    cast,
)
from typing_extensions import Self

import torch
from pycocotools.coco import COCO

from ..registries import DatasetRegistry
from .access_layers import PILAccessLayer
from .coco import BaseDataset as BaseCOCODataset
from .coco import Keys as COCOKeys

if TYPE_CHECKING:
    from todd.tasks.object_detection import BBox, FlattenBBoxesXYWH

Split = Literal['train', 'val']


class Keys(COCOKeys):

    def __init__(self, *args, ignore: Iterable[str], **kwargs) -> None:
        super().__init__(*args, **kwargs)

        if ignore := set(ignore):
            ignore_indices = [i for i in range(len(self)) if self[i] in ignore]
            assert len(ignore_indices) == len(ignore)
            for i in ignore_indices[::-1]:
                self._image_ids.pop(i)

    def _getitem(self, image_id: int) -> str:
        item = super()._getitem(image_id)
        return item.removeprefix('images/')


@dataclass(frozen=True)
class Annotation:
    area: float
    bbox: 'BBox'
    category: int
    is_crowd: bool
    is_fake: bool
    is_reflected: bool

    @classmethod
    def load(
        cls,
        annotation: Mapping[str, Any],
        categories: Mapping[int, int],
    ) -> Self:
        return cls(
            annotation['area'],
            cast('BBox', annotation['bbox']),
            categories[annotation['category_id']],
            bool(annotation['iscrowd']),
            bool(annotation['isfake']),
            bool(annotation['isreflected']),
        )


class Annotations(UserList[Annotation]):

    @classmethod
    def load(
        cls,
        coco: COCO,
        image_id: int,
        categories: Mapping[int, int],
    ) -> Self:
        annotation_ids = coco.getAnnIds([image_id])
        annotations = coco.loadAnns(annotation_ids)
        return cls(
            Annotation.load(annotation, categories)
            for annotation in annotations
        )

    @property
    def areas(self) -> torch.Tensor:
        return torch.tensor([annotation.area for annotation in self])

    @property
    def bboxes(self) -> 'FlattenBBoxesXYWH':
        from todd.tasks.object_detection import FlattenBBoxesXYWH
        if self:
            bboxes = torch.tensor([annotation.bbox for annotation in self])
        else:
            bboxes = torch.zeros(0, 4)
        return FlattenBBoxesXYWH(bboxes)

    @property
    def categories(self) -> torch.Tensor:
        return torch.tensor([annotation.category for annotation in self])

    @property
    def is_crowd(self) -> torch.Tensor:
        return torch.tensor([annotation.is_crowd for annotation in self])

    @property
    def is_fake(self) -> torch.Tensor:
        return torch.tensor([annotation.is_fake for annotation in self])

    @property
    def is_reflected(self) -> torch.Tensor:
        return torch.tensor([annotation.is_reflected for annotation in self])


class T(TypedDict):
    id_: str
    image: torch.Tensor
    annotations: Annotations


[docs] @DatasetRegistry.register_() class Objects365Dataset(BaseCOCODataset[COCO, T]): _keys: Keys DATA_ROOT = pathlib.Path('data/objects365') ANNOTATIONS_ROOT = DATA_ROOT / 'annotations' IGNORE_KEYS = ( 'v1/patch6/objects365_v1_00320532', 'v1/patch6/objects365_v1_00320534', 'v2/patch16/objects365_v2_00908726', )
[docs] def __init__( self, *args, split: Split, version: Literal['v1', 'v2'] = 'v2', access_layer: PILAccessLayer | None = None, annotations_file: pathlib.Path | str | None = None, **kwargs, ) -> None: self._split = split if access_layer is None: access_layer = PILAccessLayer( data_root=str(self.DATA_ROOT), task_name=split, suffix=self.SUFFIX, ) if annotations_file is None: annotations_file = ( self.ANNOTATIONS_ROOT / f'zhiyuan_obj{version}_{split}.json' ) coco = COCO(annotations_file) self._categories = { category_id: i for i, category_id in enumerate(coco.getCatIds()) } super().__init__(*args, api=coco, access_layer=access_layer, **kwargs)
[docs] def build_keys(self) -> Keys: ignore_keys: Iterable[str] if self._split == 'train': ignore_keys = self.IGNORE_KEYS elif self._split == 'val': ignore_keys = tuple() else: raise ValueError(f"Unknown split: {self._split}") return Keys(self._api, self.SUFFIX, ignore=ignore_keys)
def __getitem__(self, index: int) -> T: key, image = self._access(index) tensor = self._transform(image) annotations = ( Annotations.load( self._api, self._keys.image_ids[index], self._categories, ) if self._load_annotations else Annotations() ) return T(id_=key, image=tensor, annotations=annotations)