Source code for todd.datasets.access_layers.hf

# Public API of this module: only the access layer class is exported.
__all__ = [
    'HFAccessLayer',
]

import os
from typing import Iterator, TypeVar

from datasets import Dataset, DatasetDict, load_dataset

from ...bases.configs import Config
from ...bases.registries import BuildPreHookMixin, Item, RegistryMeta
from ...loggers import logger
from ..registries import AccessLayerRegistry
from .base import BaseAccessLayer

# Type of the values returned by ``HFAccessLayer.__getitem__``.
VT = TypeVar('VT')


@AccessLayerRegistry.register_()
class HFAccessLayer(BuildPreHookMixin, BaseAccessLayer[int, VT]):
    """Read-only access layer over a ``datasets.DatasetDict``.

    Items are addressed by integer position in the dataset selected by
    ``self._task_name``. Item assignment and deletion are not supported.

    NOTE(review): ``_task_name`` is not set in this class; presumably it is
    provided by ``BaseAccessLayer`` -- confirm.
    """

    def __init__(self, *args, datasets: DatasetDict, **kwargs) -> None:
        """Store the dataset mapping.

        Args:
            *args: Forwarded unchanged to the parent initializer.
            datasets: Mapping from which ``dataset`` selects one entry.
            **kwargs: Forwarded unchanged to the parent initializer.
        """
        super().__init__(*args, **kwargs)
        self._datasets = datasets

    @classmethod
    def datasets_build_pre_hook(
        cls,
        config: Config,
        registry: RegistryMeta,
        item: Item,
    ) -> Config:
        """Replace ``config.datasets`` with the loaded dataset object.

        The current value of ``config.datasets`` is expanded as keyword
        arguments to ``load_dataset``, with ``config.data_root`` passed as
        ``cache_dir``. A warning is logged when the ``HF_DATASETS_OFFLINE``
        environment variable is not exactly ``'1'``.

        Args:
            config: Build config; mutated in place.
            registry: Not used by this hook.
            item: Not used by this hook.

        Returns:
            The same ``config`` object, with ``datasets`` replaced.
        """
        offline_flag = os.getenv('HF_DATASETS_OFFLINE')
        if offline_flag != '1':
            logger.warning("'HF_DATASETS_OFFLINE=1' is not set.")

        load_kwargs = config.datasets
        cache_dir = config.data_root
        config.datasets = load_dataset(**load_kwargs, cache_dir=cache_dir)
        return config

    @classmethod
    def build_pre_hook(
        cls,
        config: Config,
        registry: RegistryMeta,
        item: Item,
    ) -> Config:
        """Run the parent pre-hook, then ``datasets_build_pre_hook``.

        Args:
            config: Build config passed through both hooks.
            registry: Forwarded to both hooks.
            item: Forwarded to both hooks.

        Returns:
            The config returned by ``datasets_build_pre_hook``.
        """
        parent_config = super().build_pre_hook(config, registry, item)
        return cls.datasets_build_pre_hook(parent_config, registry, item)

    @property
    def dataset(self) -> Dataset:
        """The entry of the stored datasets keyed by ``self._task_name``."""
        task_name = self._task_name
        return self._datasets[task_name]

    @property
    def exists(self) -> bool:
        """Always ``True``."""
        return True

    def touch(self) -> None:
        """Do nothing."""

    def __len__(self) -> int:
        """Return the length of ``self.dataset``."""
        current = self.dataset
        return len(current)

    def __iter__(self) -> Iterator[int]:
        """Iterate over the integer keys ``0`` to ``len(self) - 1``."""
        size = len(self)
        return iter(range(size))

    def __getitem__(self, key: int) -> VT:
        """Return ``self.dataset[key]``."""
        current = self.dataset
        return current[key]

    def __delitem__(self, *args, **kwargs) -> None:
        """Unsupported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def __setitem__(self, *args, **kwargs) -> None:
        """Unsupported; always raises ``NotImplementedError``."""
        raise NotImplementedError