# Source code for todd.runners.callbacks.interval

__all__ = [
    'IntervalMixin',
]

from typing import TypeVar

from torch import nn

from ...utils import Store
from .base import BaseCallback

T = TypeVar('T', bound=nn.Module)


class IntervalMixin(BaseCallback[T]):
    """Callback mixin that gates work to every ``interval`` steps.

    A "step" is either the runner iteration (``self.runner.iter_``) or, when
    ``by_epoch`` is true, the epoch of ``self.epoch_based_trainer``. The
    ``_should_run_iter`` / ``_should_run_epoch`` predicates report whether
    the current step is a hit for the respective mode.

    Args:
        *args: forwarded unchanged to ``BaseCallback.__init__``.
        interval: period in steps. A value ``<= 0`` (the default ``0``)
            means the predicates never fire. Replaced by ``1`` when
            ``Store.DRY_RUN`` is set.
        by_epoch: when true, count epochs; otherwise count iterations.
        **kwargs: forwarded unchanged to ``BaseCallback.__init__``.
    """

    def __init__(
        self,
        *args,
        interval: int = 0,
        by_epoch: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # In dry-run mode the configured interval is ignored and the
        # callback triggers on every step.
        self._interval = 1 if Store.DRY_RUN else interval
        self._by_epoch = by_epoch

    def __should_run(self, step: int) -> bool:
        """Return whether ``step`` lands on the configured interval.

        Always ``False`` for a non-positive interval; otherwise ``True`` for
        every multiple of the interval, step ``0`` included.
        """
        if self._interval <= 0:
            return False
        return step % self._interval == 0

    def _should_run_iter(self) -> bool:
        """Return whether an iteration-driven callback should fire now.

        Always ``False`` in epoch mode; ``self.runner`` is only read when
        counting iterations.
        """
        if self._by_epoch:
            return False
        return self.__should_run(self.runner.iter_)

    def _should_run_epoch(self) -> bool:
        """Return whether an epoch-driven callback should fire now.

        Always ``False`` in iteration mode; ``self.epoch_based_trainer`` is
        only read when counting epochs.
        """
        if not self._by_epoch:
            return False
        # NOTE(review): presumably ``epoch_based_trainer`` is only valid for
        # epoch-based runners, hence the guard above — confirm in BaseCallback.
        current_epoch = self.epoch_based_trainer.epoch
        return self.__should_run(current_epoch)