# Source code for todd.datasets.satin

# Public API of this module.
__all__ = ['SATINDataset']

import io
import pathlib
from typing import Any, Literal, TypedDict

import datasets
import torch
import torchvision.transforms.functional as F
from PIL import Image

from ..bases.configs import Config
from ..patches.pil import convert_rgb
from ..registries import DatasetRegistry
from .access_layers import HFAccessLayer
from .base import BaseDataset
from .index import IndexKeys
from .registries import AccessLayerRegistry


# Item returned by ``SATINDataset.__getitem__``: the sample key, the image
# as a tensor, and the remaining fields of the underlying record.
T = TypedDict(
    'T',
    {
        'id_': int,
        'image': torch.Tensor,
        'data': dict[str, Any],
    },
)


# Names of the SATIN sub-datasets; the chosen value is passed as the
# Hugging Face ``name`` when the dataset is loaded.
Split = Literal[
    'SAT-4',
    'SAT-6',
    'NASC-TG2',
    'WHU-RS19',
    'RSSCN7',
    'RS_C11',
    'SIRI-WHU',
    'EuroSAT',
    'NWPU-RESISC45',
    'PatternNet',
    'RSD46-WHU',
    'GID',
    'CLRS',
    'Optimal-31',
    'Airbus-Wind-Turbines-Patches',
    'USTC_SmokeRS',
    'Canadian_Cropland',
    'Ships-In-Satellite-Imagery',
    'Satellite-Images-of-Hurricane-Damage',
    'Brazilian_Coffee_Scenes',
    'Brazilian_Cerrado-Savanna_Scenes',
    'Million-AID',
    'UC_Merced_LandUse_MultiLabel',
    'MLRSNet',
    'MultiScene',
    'RSI-CB256',
    'AID_MultiLabel',
]


@DatasetRegistry.register_()
class SATINDataset(BaseDataset[T, int, dict[str, Any]]):
    """SATIN dataset served through a Hugging Face access layer.

    Items are ``T`` dicts holding the sample key, the image converted to RGB
    and then to a tensor, and the remaining fields of the raw record.
    """

    DATA_ROOT = pathlib.Path('data/satin')

    def __init__(
        self,
        *args,
        split: Split,
        access_layer: HFAccessLayer | None = None,
        **kwargs,
    ) -> None:
        """Create the dataset for one SATIN sub-dataset.

        Args:
            split: sub-dataset name, forwarded as the Hugging Face ``name``.
            access_layer: pre-built access layer; when ``None``, an
                ``HFAccessLayer`` rooted at ``DATA_ROOT`` is built that loads
                ``jonathan-roberts1/satin`` with ``trust_remote_code=True``
                and reads its TRAIN split.
            *args, **kwargs: forwarded to ``BaseDataset.__init__``.
        """
        if access_layer is None:
            hf_load_kwargs = dict(
                path='jonathan-roberts1/satin',
                name=split,
                trust_remote_code=True,
            )
            layer_config = Config(
                type=HFAccessLayer.__name__,
                data_root=str(self.DATA_ROOT),
                # The access layer always reads the TRAIN split; ``split``
                # selects the sub-dataset, not train/test.
                task_name=str(datasets.Split.TRAIN),
                datasets=hf_load_kwargs,
            )
            access_layer = AccessLayerRegistry.build(layer_config)
        super().__init__(*args, access_layer=access_layer, **kwargs)
        self._split = split

    def build_keys(self) -> IndexKeys:
        """Return index keys covering every record of the access layer."""
        num_records = len(self._access_layer)
        return IndexKeys(num_records)

    def _transform(self, image: Image.Image) -> torch.Tensor:
        """Apply the configured transforms, or plain PIL-to-tensor if none."""
        if self._transforms is not None:
            return self._transforms(image)
        return F.pil_to_tensor(image)

    def __getitem__(self, index: int) -> T:
        """Load record ``index`` and return it as a ``T`` item.

        The ``image`` field is removed from the record. It is used as-is when
        it is already a PIL image. Otherwise it is treated as a mapping whose
        ``'bytes'`` entry holds the encoded image.
        """
        key, record = self._access(index)
        raw_image = record.pop('image')
        pil_image = (
            raw_image if isinstance(raw_image, Image.Image) else
            Image.open(io.BytesIO(raw_image['bytes']))
        )
        pil_image = convert_rgb(pil_image)
        return T(id_=key, image=self._transform(pil_image), data=record)