Source code for todd.datasets.access_layers.pth

# Public API of this module: only the .pth access layer is exported.
__all__ = ['PthAccessLayer']

from typing import TypeVar

import torch

from ..registries import AccessLayerRegistry
from .folder import FolderAccessLayer
from .suffix import SuffixMixin

# Type variable for the values stored in / loaded from the access layer.
VT = TypeVar('VT')


@AccessLayerRegistry.register_()
class PthAccessLayer(SuffixMixin[VT], FolderAccessLayer[VT]):
    """Folder-backed access layer whose entries are PyTorch-serialized files.

    Each key maps to a file resolved by ``self._file(key)`` (provided by a
    base class); this class passes ``suffix='pth'`` to its bases and uses
    ``torch.save`` / ``torch.load`` for (de)serialization.
    """

    def __init__(self, *args, **kwargs) -> None:
        # All caller arguments are forwarded unchanged; ``suffix`` is supplied
        # here, so callers must not pass it themselves (duplicate keyword).
        super().__init__(*args, suffix='pth', **kwargs)

    def __getitem__(self, key: str) -> VT:
        """Load and return the object stored under ``key``.

        Tensors in the file are mapped onto the CPU via ``map_location``.
        """
        # NOTE(review): torch.load unpickles the file, so only use this layer
        # on trusted data. Its ``weights_only`` default also differs between
        # torch versions -- confirm which object types must be loadable.
        path = self._file(key)
        return torch.load(path, map_location='cpu')

    def __setitem__(self, key: str, value: VT) -> None:
        """Serialize ``value`` with ``torch.save`` to the file for ``key``."""
        path = self._file(key)
        torch.save(value, path)