"""ImageNet dataset (module ``todd.datasets.imagenet``)."""

# Public API of this module: only the dataset class is exported.
__all__ = ['ImageNetDataset']

import json
import os
import pathlib
from abc import ABC
from typing import Literal, TypedDict

import torch

from ..registries import DatasetRegistry
from .access_layers import PILAccessLayer
from .base import KeysProtocol
from .pil import PILDataset

# Dataset split name. It selects the default annotation file
# (``<split>.json``) and is passed to the default access layer as
# ``task_name``.
Split = Literal['train', 'val']

# Schema of one record of the synsets JSON file.
Synset = TypedDict(
    'Synset',
    {
        'WNID': str,  # WordNet id; used as the image sub-directory in keys
        'words': str,
        'gloss': str,
        'num_children': int,
        'children': list[int],
        'wordnet_height': int,
        'num_train_images': int,
    },
)

# Synset records keyed by integer synset id (the JSON file stores the ids as
# string object keys, which the dataset converts with ``int``).
Synsets = dict[int, Synset]


# One record of an annotation file: the image file name and its synset id.
Annotation = TypedDict('Annotation', {'name': str, 'synset_id': int})

# An annotation file is a JSON list of records; a record's list position is
# the dataset index used to look it up.
Annotations = list[Annotation]


class Keys(KeysProtocol[str]):  # pylint: disable=unsubscriptable-object
    """Derive an image key for each annotation record, on demand.

    The key of record ``i`` is ``os.path.join(WNID, stem)``. ``WNID`` is
    looked up in ``synsets`` by the record's ``synset_id``. ``stem`` is the
    record's ``name`` with a trailing ``.<suffix>`` removed, if present.
    """

    def __init__(
        self,
        annotations: Annotations,
        synsets: Synsets,
        suffix: str,
    ) -> None:
        """Store the annotation records, synset table and file suffix.

        Args:
            annotations: annotation records, indexed by dataset position.
            synsets: synset records keyed by integer synset id.
            suffix: file extension (without the dot) to strip from names.
        """
        self._annotations = annotations
        self._synsets = synsets
        self._suffix = suffix

    def __len__(self) -> int:
        """Return the number of keys, i.e. the number of annotation records."""
        return len(self._annotations)

    def __getitem__(self, index: int) -> str:
        """Return the key of the annotation record at ``index``."""
        record = self._annotations[index]
        wnid = self._synsets[record['synset_id']]['WNID']
        stem = record['name'].removesuffix('.' + self._suffix)
        return os.path.join(wnid, stem)


# One dataset sample, as built by ``ImageNetDataset.__getitem__``.
T = TypedDict(
    'T',
    {
        'id_': str,  # key of the image, as produced by the dataset's keys
        'image': torch.Tensor,  # image after the dataset's transform
        'category': int,  # contiguous category index of the sample's synset
    },
)


@DatasetRegistry.register_()
class ImageNetDataset(PILDataset[T], ABC):
    """ImageNet dataset whose samples are ``T`` dicts.

    Images are read through a ``PILAccessLayer``. Annotation records and the
    synset table are loaded from JSON files when the dataset is constructed.

    Class attributes:
        DATA_ROOT: default dataset root, also used as the default access-layer
            root.
        ANNOTATIONS_ROOT: directory holding the per-split ``<split>.json``
            annotation files.
        SYNSETS_FILE: default synsets JSON file.
        SUFFIX: image file extension, without the leading dot.
    """

    DATA_ROOT = pathlib.Path('data/imagenet')
    ANNOTATIONS_ROOT = DATA_ROOT / 'annotations'
    SYNSETS_FILE = DATA_ROOT / 'synsets.json'
    SUFFIX = 'JPEG'

    def __init__(
        self,
        *args,
        split: Split,
        access_layer: PILAccessLayer | None = None,
        annotations_file: pathlib.Path | str | None = None,
        synsets_file: pathlib.Path | str | None = None,
        **kwargs,
    ) -> None:
        """Load annotations and synsets, then initialize the base dataset.

        Args:
            *args: forwarded to the base class initializer.
            split: dataset split. It selects the default annotation file and
                is passed as ``task_name`` to the default access layer.
            access_layer: image access layer. Defaults to a
                ``PILAccessLayer`` with ``data_root=DATA_ROOT``,
                ``subfolder_action='walk'`` and ``suffix=SUFFIX``.
            annotations_file: JSON list of ``Annotation`` records. Defaults to
                ``ANNOTATIONS_ROOT / f'{split}.json'``.
            synsets_file: JSON object mapping synset ids (as strings) to
                ``Synset`` records. Defaults to ``SYNSETS_FILE``.
            **kwargs: forwarded to the base class initializer.

        Raises:
            OSError: if a JSON file cannot be opened.
            json.JSONDecodeError: if a JSON file is malformed.
        """
        if access_layer is None:
            access_layer = PILAccessLayer(
                data_root=str(self.DATA_ROOT),
                task_name=split,
                subfolder_action='walk',
                suffix=self.SUFFIX,
            )

        if annotations_file is None:
            annotations_file = self.ANNOTATIONS_ROOT / f'{split}.json'
        elif isinstance(annotations_file, str):
            annotations_file = pathlib.Path(annotations_file)

        if synsets_file is None:
            synsets_file = self.SYNSETS_FILE
        elif isinstance(synsets_file, str):
            synsets_file = pathlib.Path(synsets_file)

        # JSON text is UTF-8 (RFC 8259). Do not depend on the locale encoding.
        with annotations_file.open(encoding='utf-8') as f:
            self._annotations: Annotations = json.load(f)
        # Read the *resolved* ``synsets_file``. The old code always read
        # ``self.SYNSETS_FILE``, which silently ignored the argument.
        with synsets_file.open(encoding='utf-8') as f:
            raw_synsets: dict[str, Synset] = json.load(f)

        # JSON object keys are always strings; synset ids are ints.
        synsets: Synsets = {int(k): v for k, v in raw_synsets.items()}
        self._synsets = synsets
        # Map (possibly sparse) synset ids to contiguous category indices,
        # assigned in ascending synset-id order.
        self._categories = {
            synset_id: i for i, synset_id in enumerate(sorted(synsets))
        }

        # Keep this last. The attributes above are presumably needed by the
        # base initializer (e.g. via ``build_keys``) -- verify before moving.
        super().__init__(*args, access_layer=access_layer, **kwargs)

    def build_keys(self) -> Keys:
        """Build the image keys, one per annotation record, in record order."""
        return Keys(self._annotations, self._synsets, self.SUFFIX)

    def __getitem__(self, index: int) -> T:
        """Return sample ``index``: its key, transformed image and category.

        ``category`` is the contiguous index from ``self._categories``, not
        the raw synset id stored in the annotation record.
        """
        # ``_access`` / ``_transform`` are presumably provided by the base
        # dataset class.
        key, image = self._access(index)
        tensor = self._transform(image)
        synset_id = self._annotations[index]['synset_id']
        category = self._categories[synset_id]
        return T(id_=key, image=tensor, category=category)