Source code for todd.runners.callbacks.base

# pylint: disable=pointless-statement

__all__ = [
    'BaseCallback',
]

import contextlib
from typing import Any, TypeVar

from torch import nn

from ...utils import StateDictMixin
from ..memo import Memo
from ..registries import CallbackRegistry
from ..utils import RunnerHolderMixin

T = TypeVar('T', bound=nn.Module)


[docs] @CallbackRegistry.register_() class BaseCallback(RunnerHolderMixin[T], StateDictMixin):
[docs] def should_break(self, batch: Any, memo: Memo) -> bool: """Determine whether to break the run loop. Args: batch: inputs. memo: runtime memory. Returns: Whether to break the run loop. Override this method for early stopping, error detection, etc. By default, this method returns `False` and the run loop ends normally when the dataloader is exhausted. """ return False
[docs] def should_continue(self, batch: Any, memo: Memo) -> bool: """Determine whether to skip the current iteration. Args: batch: inputs. memo: runtime memory. Returns: Whether to skip the current iteration. """ return False
[docs] def before_run_iter(self, batch: Any, memo: Memo) -> None: pass
[docs] def run_iter_context( self, exit_stack: contextlib.ExitStack, batch: Any, memo: Memo, ) -> None: pass
[docs] def after_run_iter(self, batch: Any, memo: Memo) -> None: pass
[docs] def should_break_epoch(self, epoch_memo: Memo, memo: Memo) -> bool: self.epoch_based_trainer return False
[docs] def should_continue_epoch(self, epoch_memo: Memo, memo: Memo) -> bool: self.epoch_based_trainer return False
[docs] def before_run_epoch(self, epoch_memo: Memo, memo: Memo) -> None: self.epoch_based_trainer
[docs] def run_epoch_context( self, exit_stack: contextlib.ExitStack, epoch_memo: Memo, memo: Memo, ) -> None: self.epoch_based_trainer
[docs] def after_run_epoch(self, epoch_memo: Memo, memo: Memo) -> None: self.epoch_based_trainer
[docs] def before_run(self, memo: Memo) -> None: pass
[docs] def after_run(self, memo: Memo) -> None: pass