"""Iteration-based trainer runner (``todd.runners.iter_based_trainer``)."""

# Public names exported by ``from todd.runners.iter_based_trainer import *``.
__all__ = ['IterBasedTrainer']

import itertools
from typing import Any, Generator, TypeVar

from torch import nn

from ..patches.torch import set_epoch
from ..registries import RunnerRegistry
from .memo import Memo
from .trainer import Trainer

# Type parameter for the trained model; restricted to ``nn.Module`` subclasses.
T = TypeVar('T', bound=nn.Module)


@RunnerRegistry.register_()
class IterBasedTrainer(Trainer[T]):
    """Trainer whose run length is a fixed total number of iterations.

    Batches are drawn from the dataloader until ``self._iter`` reaches
    ``iters``; the dataloader is re-iterated (with ``set_epoch`` called
    before each pass) as many times as needed to reach that budget.
    """

    def __init__(self, *args, iters: int, **kwargs) -> None:
        """Initialize the trainer.

        Args:
            *args: Forwarded unchanged to ``Trainer.__init__``.
            iters: Total number of training iterations. Coerced with
                ``int`` because Python interprets numbers like ``2e3``
                as floats.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # python interprets numbers like 2e3 as floats
        self._iters = int(iters)

    @property
    def iters(self) -> int:
        """Total number of training iterations (the stop condition)."""
        return self._iters

    def _iterate_dataloader(self) -> Generator[Any, None, None]:
        """Yield batches until ``self.iters`` iterations have been reached.

        If ``inner_iter`` is positive (resuming part-way through an epoch),
        the batches already consumed in the current epoch are skipped
        first. Afterwards, whole passes over the dataloader are made, each
        truncated so that no more than ``self.iters - self._iter`` batches
        are yielded.

        The loop condition reads ``self._iter``, which this generator never
        modifies itself. It relies on the consumer advancing it.

        Raises:
            RuntimeError: If a full pass over the dataloader yields no
                batches while iterations remain. Without this check the
                ``while`` loop would spin forever without making progress.
        """
        # NOTE(review): ``inner_iter``, ``epoch`` and ``iters_per_epoch`` are
        # provided by the ``Trainer`` base class (not visible here);
        # presumably ``inner_iter`` is the position within the current epoch
        # and ``epoch`` the number of completed epochs — confirm.
        if self.inner_iter > 0:
            set_epoch(self._dataloader, self.epoch)
            # Skip batches already consumed in this epoch, and stop at the
            # global budget measured from the start of this epoch.
            yield from itertools.islice(
                self._dataloader,
                self.inner_iter,
                self.iters - self.iters_per_epoch * self.epoch,
            )
        while self._iter < self.iters:
            # Every pass after the (optional) resumed one starts at the
            # beginning of an epoch.
            assert self.inner_iter == 0
            set_epoch(self._dataloader, self.epoch)
            produced_any = False
            for batch in itertools.islice(
                self._dataloader,
                self.iters - self._iter,
            ):
                produced_any = True
                yield batch
            if not produced_any:
                # An empty pass cannot advance ``self._iter``; repeating it
                # would loop forever.
                raise RuntimeError(
                    f'dataloader yielded no batches at iteration '
                    f'{self._iter}/{self.iters}',
                )

    def _setup(self) -> Memo:
        """Extend the base setup memo with the iteration-bounded batch stream.

        Returns:
            The memo from ``Trainer._setup`` with key ``'dataloader'`` set to
            the generator returned by ``_iterate_dataloader``.
        """
        memo = super()._setup()
        # The generator body runs lazily, so ``inner_iter``/``epoch`` are read
        # on the first ``next()``, not here.
        memo['dataloader'] = self._iterate_dataloader()
        return memo