# Source code for todd.datasets.laion_aesthetics

# Names exported by a star-import of this module.
__all__ = [
    'LAIONAestheticsDataset',
]

import csv
import pathlib
from abc import ABC
from typing import Literal, TypedDict

import torch

from ..registries import DatasetRegistry
from .access_layers import PILAccessLayer
from .base import KeysProtocol
from .pil import PILDataset

# Split names accepted by the dataset; only 'v2_6.5plus' is declared.
Split = Literal['v2_6.5plus']


class Annotation(TypedDict):
    """One parsed row of the tab-separated annotations file."""
    filename: str  # first column; also serves as the sample key
    caption: str  # columns between filename and score, re-joined with tabs
    score: float  # second-to-last column, parsed with float()
    url: str  # last column


# All rows of one annotations file, in file order.
Annotations = list[Annotation]


class Keys(KeysProtocol[str]):  # pylint: disable=unsubscriptable-object
    """Indexable view that exposes each annotation's filename as its key."""

    def __init__(self, annotations: Annotations) -> None:
        """Hold a reference to ``annotations``; the list is not copied."""
        self._annotations = annotations

    def __len__(self) -> int:
        """Return the number of annotations, i.e. the number of keys."""
        count = len(self._annotations)
        return count

    def __getitem__(self, index: int) -> str:
        """Return the ``filename`` of the annotation stored at ``index``."""
        record = self._annotations[index]
        return record['filename']


class T(TypedDict):
    """Sample dictionary built by ``LAIONAestheticsDataset.__getitem__``."""
    id_: str  # key returned by the dataset's ``_access`` call
    image: torch.Tensor  # result of ``_transform`` applied to the loaded image
    caption: str  # copied from the matching annotation
    score: float  # copied from the matching annotation


@DatasetRegistry.register_()
class LAIONAestheticsDataset(PILDataset[T], ABC):
    """Image dataset pairing LAION-Aesthetics images with caption and score.

    Annotations are read from a tab-separated file in which every row is
    laid out as ``filename, caption..., score, url``. Samples are returned
    as `T` dictionaries.
    """

    DATA_ROOT = pathlib.Path('data/laion/aesthetics')
    ANNOTATIONS_ROOT = DATA_ROOT / 'annotations'
    SUFFIX = None

    def __init__(
        self,
        *args,
        split: Split,
        access_layer: PILAccessLayer | None = None,
        annotations_file: pathlib.Path | str | None = None,
        **kwargs,
    ) -> None:
        """Load the annotations and initialize the underlying dataset.

        Args:
            *args: forwarded to the parent initializer.
            split: name of the split; used as the access layer's
                ``task_name`` and as the stem of the default annotations
                file.
            access_layer: image access layer. When ``None``, a
                `PILAccessLayer` rooted at `DATA_ROOT` is created.
            annotations_file: path of the TSV annotations file. When
                ``None``, ``ANNOTATIONS_ROOT / f'{split}.tsv'`` is used.
            **kwargs: forwarded to the parent initializer.

        Raises:
            OSError: if the annotations file cannot be opened.
            ValueError: if a score column cannot be parsed as a float.
        """
        if access_layer is None:
            access_layer = PILAccessLayer(
                data_root=str(self.DATA_ROOT),
                task_name=split,
                subfolder_action='none',
                suffix=self.SUFFIX,
            )

        if annotations_file is None:
            annotations_file = self.ANNOTATIONS_ROOT / f'{split}.tsv'
        elif isinstance(annotations_file, str):
            annotations_file = pathlib.Path(annotations_file)

        # ``newline=''`` is what the csv module requires of file objects so
        # that quoted fields with embedded newlines are parsed correctly.
        # The explicit encoding keeps caption decoding independent of the
        # platform locale (assumes the TSV is UTF-8 - TODO confirm).
        with annotations_file.open(encoding='utf-8', newline='') as f:
            self._annotations = [
                Annotation(
                    filename=annotation[0],
                    # A caption may itself contain tabs, so everything
                    # between the filename and the last two columns is
                    # stitched back together.
                    caption='\t'.join(annotation[1:-2]),
                    score=float(annotation[-2]),
                    url=annotation[-1],
                )
                for annotation in csv.reader(f, delimiter='\t')
                # csv.reader yields [] for blank lines; skipping them
                # avoids an IndexError on e.g. a trailing empty line.
                if annotation
            ]

        # The annotations are assigned before the parent initializer runs;
        # presumably it calls ``build_keys`` - verify against PILDataset.
        super().__init__(*args, access_layer=access_layer, **kwargs)

    def build_keys(self) -> Keys:
        """Return a `Keys` view over the loaded annotations."""
        return Keys(self._annotations)

    def __getitem__(self, index: int) -> T:
        """Return the sample at ``index``.

        The image is fetched with ``_access`` and passed through
        ``_transform``; caption and score come from the annotation at the
        same index.
        """
        key, image = self._access(index)
        tensor = self._transform(image)
        annotation = self._annotations[index]
        return T(
            id_=key,
            image=tensor,
            caption=annotation['caption'],
            score=annotation['score'],
        )