Source code for todd.runners.callbacks.monitor

__all__ = [
    'MonitorCallback',
]

import contextlib
import logging
from typing import Any, TypeVar

from torch import nn

from ..memo import Memo
from ..registries import CallbackRegistry
from .base import BaseCallback

T = TypeVar('T', bound=nn.Module)


class Context:

    def __init__(self, logger: logging.Logger, **kwargs) -> None:
        self._logger = logger
        self._kwargs = kwargs

    def __enter__(self) -> None:
        pass

    def __exit__(self, *exc_info) -> None:
        if all(i is None for i in exc_info):
            return
        message = '\n'.join(f'{k}={v}' for k, v in self._kwargs.items())
        self._logger.exception("Unable to run " + message)


[docs] @CallbackRegistry.register_() class MonitorCallback(BaseCallback[T]):
[docs] def run_iter_context( self, exit_stack: contextlib.ExitStack, batch: Any, memo: Memo, ) -> None: super().run_iter_context(exit_stack, batch, memo) context = Context( self.runner.logger, iter_=self.runner.iter_, batch=batch, memo=memo, ) exit_stack.enter_context(context)
[docs] def run_epoch_context( self, exit_stack: contextlib.ExitStack, epoch_memo: Memo, memo: Memo, ) -> None: super().run_epoch_context(exit_stack, epoch_memo, memo) runner = self.epoch_based_trainer context = Context( runner.logger, epoch=runner.epoch, epoch_memo=epoch_memo, memo=memo, ) exit_stack.enter_context(context)