Source code for todd.datasets.pil

__all__ = [
    'PILDataset',
]

from abc import ABC
from typing import TypeVar

import torch
import torchvision.transforms.functional as F
from PIL import Image

from ..patches.pil import convert_rgb
from .base import BaseDataset

VT = Image.Image
T = TypeVar('T')


[docs] class PILDataset(BaseDataset[T, str, VT], ABC): def _access(self, index: int) -> tuple[str, VT]: key, image = super()._access(index) image = convert_rgb(image) return key, image def _transform(self, image: Image.Image) -> torch.Tensor: if self._transforms is None: return F.pil_to_tensor(image) return self._transforms(image)