Source code for todd.runners.epoch_based_trainer

# Names exported by ``from ... import *``: only the trainer class.
__all__ = ['EpochBasedTrainer']

import contextlib
import itertools
from collections import defaultdict
from typing import TypeVar

from torch import nn

from ..patches.torch import set_epoch
from ..registries import RunnerRegistry
from .memo import Memo
from .trainer import Trainer

# Type variable for the model type the trainer is parameterized over
# (``Trainer[T]``); bound to ``nn.Module``.
T = TypeVar('T', bound=nn.Module)


@RunnerRegistry.register_()
class EpochBasedTrainer(Trainer[T]):
    """Trainer that repeats the inherited iteration loop once per epoch.

    Every epoch gets its own memo, built by the base ``_setup`` and driven by
    the base ``_run``. When an epoch finishes, its ``'epoch'`` records are
    stored in the run-level memo under ``'epoch_memos'``.
    """

    def __init__(self, *args, epochs: int, **kwargs) -> None:
        """Create the trainer.

        Args:
            epochs: number of epochs to run.
            *args: forwarded unchanged to ``Trainer.__init__``.
            **kwargs: forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): set only after the base initializer returns, so the
        # base ``__init__`` must not read ``iters``/``epochs`` — confirm.
        self._epochs = epochs

    @property
    def iters(self) -> int:
        """Total iteration count: iterations per epoch times epochs."""
        per_epoch = self.iters_per_epoch
        return per_epoch * self._epochs

    @property
    def epochs(self) -> int:
        """Number of epochs this trainer was configured to run."""
        return self._epochs

    def _run_epoch(self, epoch_memo: Memo, memo: Memo) -> Memo:
        """Drive one epoch through the inherited ``Trainer._run`` loop.

        ``memo`` (the run-level memo) matches the other per-epoch hooks'
        signatures but is not used here; only ``epoch_memo`` is passed on.
        """
        return super()._run(epoch_memo)

    def _setup_epoch(self, memo: Memo) -> Memo:
        """Build the memo for the epoch that is about to start.

        Returns:
            The memo from the base ``_setup``, extended with
            ``'dataloader'`` (the batches to iterate this epoch) and
            ``'epoch'`` (a ``defaultdict(list)`` for per-epoch records).
        """
        epoch_memo = super()._setup()
        # NOTE(review): presumably forwards the epoch index to the loader's
        # sampler (e.g. for per-epoch shuffling) — confirm in patches.torch.
        set_epoch(self._dataloader, self.epoch)

        if self.inner_iter > 0:
            # Part of this epoch has already run (presumably a resume): drop
            # that many leading batches. ``islice`` still draws them from the
            # loader; it just does not yield them.
            batches = itertools.islice(self._dataloader, self.inner_iter, None)
        else:
            batches = self._dataloader

        epoch_memo.update(dataloader=batches, epoch=defaultdict(list))
        return epoch_memo

    def _teardown_epoch(self, epoch_memo: Memo, memo: Memo) -> None:
        """Finish an epoch and file its records into the run-level memo."""
        super()._teardown(epoch_memo)
        records = epoch_memo['epoch']
        # NOTE(review): ``self.epoch`` is read after the epoch's iterations
        # ran; if it is derived from the iteration counter, this key is the
        # index of the *next* epoch — confirm the intended keying.
        memo['epoch_memos'][self.epoch] = records

    def _run(self, memo: Memo) -> Memo:
        """Run epochs until ``self.epoch`` reaches the configured count.

        Per epoch, the callback order is: ``should_break_epoch``,
        ``should_continue_epoch``, ``before_run_epoch``,
        ``run_epoch_context`` (entered around the epoch), then
        ``after_run_epoch``, followed by ``_teardown_epoch``.

        Returns:
            The same ``memo`` object that was passed in, mutated in place.
        """
        while self.epoch < self._epochs:
            epoch_memo = self._setup_epoch(memo)

            # NOTE(review): both exits below skip ``_teardown_epoch`` for the
            # memo just built. Nothing in this method advances ``self.epoch``
            # on the ``continue`` path, so a callback that keeps asking to
            # continue would loop forever unless something else moves it.
            if self._callbacks.should_break_epoch(epoch_memo, memo):
                break
            if self._callbacks.should_continue_epoch(epoch_memo, memo):
                continue

            self._callbacks.before_run_epoch(epoch_memo, memo)
            with contextlib.ExitStack() as stack:
                # Callbacks may push context managers onto ``stack``; they
                # stay entered exactly for the duration of this epoch.
                self._callbacks.run_epoch_context(stack, epoch_memo, memo)
                epoch_memo = self._run_epoch(epoch_memo, memo)
            self._callbacks.after_run_epoch(epoch_memo, memo)

            self._teardown_epoch(epoch_memo, memo)
        return memo

    def _setup(self) -> Memo:
        """Create the run-level memo with an empty ``'epoch_memos'`` map."""
        return {'epoch_memos': {}}

    def _teardown(self, memo: Memo) -> None:
        """Run-level teardown; intentionally does nothing."""