Source code for todd.patches.torch.utils

__all__ = [
    'set_epoch',
    'PrefetchDataLoader',
]

from typing import Any, Iterator, TypeVar

import torch
from torch.utils.data import DataLoader

T = TypeVar('T')


[docs] def set_epoch(dataloader: DataLoader, epoch: int) -> None: samplers = [ dataloader.sampler, dataloader.batch_sampler, getattr(dataloader.batch_sampler, 'sampler', None), ] for sampler in samplers: set_epoch_ = getattr(sampler, 'set_epoch', None) if set_epoch_ is not None: set_epoch_(epoch)
def cuda(x: Any) -> Any: if isinstance(x, torch.Tensor): return x.cuda(non_blocking=True) return x
[docs] class PrefetchDataLoader(DataLoader[T]):
[docs] def __init__(self, *args, **kwargs) -> None: from ...utils import NestedCollectionUtils super().__init__(*args, **kwargs) self._stream = torch.cuda.Stream() self._utils = NestedCollectionUtils()
def _prefetch(self, iter_: Iterator[Any]) -> Iterator[Any] | None: try: batch = next(iter_) except StopIteration: return None with torch.cuda.stream(self._stream): return self._utils.map(cuda, batch) # type: ignore[arg-type] def __iter__(self) -> Iterator[Any]: # type: ignore[override] iter_ = super().__iter__() batch = self._prefetch(iter_) while batch is not None: torch.cuda.current_stream().wait_stream(self._stream) next_batch = self._prefetch(iter_) yield batch batch = next_batch