# Source code for todd.datasets.coco

# Names exported by ``from todd.datasets.coco import *``.
__all__ = [
    'coco_url',
    'COCODataset',
]

import pathlib
from abc import ABC, abstractmethod
from collections import UserList
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Generic,
    Iterable,
    Literal,
    Mapping,
    TypedDict,
    TypeVar,
    cast,
)
from typing_extensions import Self

import torch
from pycocotools.coco import COCO

from ..registries import DatasetRegistry
from .access_layers import PILAccessLayer
from .base import KeysProtocol
from .pil import PILDataset

if TYPE_CHECKING:
    from pycocotools.coco import _Annotation

    from todd.tasks.object_detection import BBox, FlattenBBoxesXYWH

# Base address that `coco_url` prefixes to every generated image URL.
URL = 'http://images.cocodataset.org/'
# Split names accepted by this module.
Split = Literal['train', 'val']
# Release years accepted by this module.
Year = Literal[2014, 2017]


def coco_url(split: Split, year: Year, id_: int) -> str:
    """Return the download URL of a COCO image.

    The URL has the form ``{URL}{split}{year}/{id_}.jpg`` where the image id
    is zero-padded to twelve digits.

    Args:
        split: dataset split, ``'train'`` or ``'val'``.
        year: dataset release year.
        id_: integer image id.

    Returns:
        The URL as a string.

    Example:
        >>> coco_url('val', 2017, 139)
        'http://images.cocodataset.org/val2017/000000000139.jpg'
    """
    return f'{URL}{split}{year}/{id_:012d}.jpg'
class BaseKeys(KeysProtocol[str], ABC):
    """Index-addressable collection of string keys backed by image ids.

    Subclasses supply `_getitem`, which maps an image id to a name; the
    public `__getitem__` strips the trailing ``.{suffix}`` from that name.
    """

    def __init__(self, image_ids: Iterable[int], suffix: str) -> None:
        """Store the image ids and the suffix to strip from names.

        Args:
            image_ids: image ids; materialised into a list.
            suffix: extension (without the dot) removed from each name.
        """
        self._suffix = suffix
        self._image_ids = list(image_ids)

    @property
    def image_ids(self) -> list[int]:
        """The stored list of image ids (the same list object, not a copy)."""
        return self._image_ids

    def __len__(self) -> int:
        """Number of stored image ids."""
        return len(self._image_ids)

    @abstractmethod
    def _getitem(self, image_id: int) -> str:
        """Return the name associated with ``image_id``."""

    def __getitem__(self, index: int) -> str:
        """Return the key at position ``index``.

        The key is the name produced by `_getitem` for the image id stored
        at ``index``, with a trailing ``.{suffix}`` removed when present.
        """
        image_id = self._image_ids[index]
        name = self._getitem(image_id)
        return name.removesuffix(f'.{self._suffix}')


class Keys(BaseKeys):
    """Keys whose names are read from a ``COCO`` API object."""

    def __init__(self, coco: COCO, *args, **kwargs) -> None:
        """Collect every image id known to ``coco``.

        Args:
            coco: COCO API handle used for id and file-name lookups.
            *args: forwarded to `BaseKeys.__init__` after the image ids.
            **kwargs: forwarded to `BaseKeys.__init__`.
        """
        self._coco = coco
        super().__init__(coco.getImgIds(), *args, **kwargs)

    def _getitem(self, image_id: int) -> str:
        """Return the ``file_name`` of the image record for ``image_id``."""
        # Unpacking into a one-element target fails unless exactly one
        # record is returned.
        [record] = self._coco.loadImgs(image_id)
        return record['file_name']


@dataclass(frozen=True)
class Annotation:
    """Immutable record for one annotation."""

    # Tensor from `coco.annToMask`, or `torch.zeros(1)` as a placeholder
    # when the raw annotation has no 'segmentation' entry.
    mask: torch.Tensor
    # Raw `annotation['area']`.
    area: float
    # `bool(annotation['iscrowd'])`.
    is_crowd: bool
    # Raw `annotation['bbox']`, only re-typed via `cast` (no conversion).
    bbox: 'BBox'
    # `categories[annotation['category_id']]`.
    category: int

    @classmethod
    def load(
        cls,
        coco: COCO,
        annotation: '_Annotation',
        categories: Mapping[int, int],
    ) -> Self:
        """Build an `Annotation` from a raw COCO annotation.

        Args:
            coco: COCO API handle, used to rasterise the segmentation.
            annotation: raw annotation record.
            categories: lookup from ``annotation['category_id']`` to the
                value stored in `category`.

        Returns:
            A new instance of ``cls``.
        """
        if 'segmentation' in annotation:
            mask = torch.from_numpy(coco.annToMask(annotation))
        else:
            mask = torch.zeros(1)

        area = annotation['area']
        is_crowd = bool(annotation['iscrowd'])
        bbox = cast('BBox', annotation['bbox'])
        category = categories[annotation['category_id']]
        return cls(mask, area, is_crowd, bbox, category)


class Annotations(UserList[Annotation]):
    """List of `Annotation` objects with per-field tensor views."""

    @classmethod
    def load(
        cls,
        coco: COCO,
        image_id: int,
        categories: Mapping[int, int],
    ) -> Self:
        """Load all annotations that ``coco`` reports for ``image_id``.

        Args:
            coco: COCO API handle.
            image_id: id of the image whose annotations are loaded.
            categories: forwarded to `Annotation.load`.

        Returns:
            A new instance of ``cls`` holding one `Annotation` per record.
        """
        records = coco.loadAnns(coco.getAnnIds([image_id]))
        return cls([
            Annotation.load(coco, record, categories) for record in records
        ])

    @property
    def masks(self) -> torch.Tensor:
        """All masks combined with `torch.stack`.

        Unlike `bboxes`, there is no special case for an empty collection.
        """
        return torch.stack([item.mask for item in self])

    @property
    def areas(self) -> torch.Tensor:
        """Tensor built from every annotation's `area`."""
        return torch.tensor([item.area for item in self])

    @property
    def is_crowd(self) -> torch.Tensor:
        """Tensor built from every annotation's `is_crowd` flag."""
        return torch.tensor([item.is_crowd for item in self])

    @property
    def bboxes(self) -> 'FlattenBBoxesXYWH':
        """All boxes wrapped in `FlattenBBoxesXYWH`.

        An empty collection yields a wrapper around a ``(0, 4)`` zero tensor.
        """
        # Imported here rather than at module level; the module-level import
        # is guarded by TYPE_CHECKING.
        from todd.tasks.object_detection import FlattenBBoxesXYWH

        data = (
            torch.tensor([item.bbox for item in self])
            if self else torch.zeros(0, 4)
        )
        return FlattenBBoxesXYWH(data)

    @property
    def categories(self) -> torch.Tensor:
        """Tensor built from every annotation's `category`."""
        return torch.tensor([item.category for item in self])


APIType = TypeVar('APIType')  # pylint: disable=invalid-name
DataType = TypeVar('DataType')  # pylint: disable=invalid-name


class BaseDataset(
    PILDataset[DataType],
    Generic[APIType, DataType],
    ABC,
):
    """`PILDataset` that additionally carries an API object.

    Type parameters:
        APIType: type of the wrapped API object.
        DataType: item type, passed on to `PILDataset`.
    """

    # Suffix used by subclasses in this module when building keys and
    # access layers.
    SUFFIX = 'jpg'

    def __init__(
        self,
        *args,
        load_annotations: bool = True,
        api: APIType,
        **kwargs,
    ) -> None:
        """Record the API object and the annotation flag, then defer to the
        parent initialiser.

        Args:
            *args: forwarded to the parent initialiser.
            load_annotations: stored as ``_load_annotations`` for subclasses.
            api: API object exposed through the `api` property.
            **kwargs: forwarded to the parent initialiser.
        """
        # Both attributes are assigned before the parent initialiser runs.
        self._api = api
        self._load_annotations = load_annotations
        super().__init__(*args, **kwargs)

    @property
    def api(self) -> APIType:
        """The API object given at construction."""
        return self._api


class T(TypedDict):
    """Item layout: key string, image tensor and annotations."""

    id_: str
    image: torch.Tensor
    annotations: Annotations
@DatasetRegistry.register_()
class COCODataset(BaseDataset[COCO, T]):
    """COCO dataset registered with `DatasetRegistry`.

    Items are `T` dictionaries holding the key, the transformed image and
    the image's annotations.
    """

    _keys: Keys

    # Default on-disk locations, relative to the working directory.
    DATA_ROOT = pathlib.Path('data/coco')
    ANNOTATIONS_ROOT = DATA_ROOT / 'annotations'

    def __init__(
        self,
        *args,
        split: Split,
        year: Year = 2017,
        access_layer: PILAccessLayer | None = None,
        annotations_file: pathlib.Path | str | None = None,
        **kwargs,
    ) -> None:
        """Open the annotation file and initialise the dataset.

        Args:
            *args: forwarded to the parent initialiser.
            split: dataset split.
            year: dataset release year.
            access_layer: image access layer. When ``None``, a
                `PILAccessLayer` is created with ``DATA_ROOT`` as data root,
                ``f'{split}{year}'`` as task name and ``SUFFIX`` as suffix.
            annotations_file: path of the annotation JSON. When ``None``,
                ``ANNOTATIONS_ROOT / f'instances_{split}{year}.json'``.
            **kwargs: forwarded to the parent initialiser.
        """
        split_year = f'{split}{year}'

        if access_layer is None:
            access_layer = PILAccessLayer(
                data_root=str(self.DATA_ROOT),
                task_name=split_year,
                suffix=self.SUFFIX,
            )

        if annotations_file is None:
            annotations_file = (
                self.ANNOTATIONS_ROOT / f'instances_{split_year}.json'
            )

        coco = COCO(annotations_file)

        # Map each COCO category id to its position in `getCatIds()`, giving
        # consecutive indices starting at 0.
        # NOTE(review): assigned before super().__init__(); ordering kept
        # from the original.
        self._categories = {
            category_id: i
            for i, category_id in enumerate(coco.getCatIds())
        }

        super().__init__(*args, api=coco, access_layer=access_layer, **kwargs)

    def build_keys(self) -> Keys:
        """Create the `Keys` object from the COCO API and ``SUFFIX``."""
        return Keys(self._api, self.SUFFIX)

    def __getitem__(self, index: int) -> T:
        """Return the item at ``index``.

        The image is fetched with ``_access`` and passed through
        ``_transform``. Annotations are loaded for the image id at the same
        index, unless annotation loading was disabled, in which case an empty
        `Annotations` is returned.
        """
        key, image = self._access(index)
        tensor = self._transform(image)

        if self._load_annotations:
            annotations = Annotations.load(
                self._api,
                self._keys.image_ids[index],
                self._categories,
            )
        else:
            annotations = Annotations()

        return T(id_=key, image=tensor, annotations=annotations)