Source code for todd.runners.callbacks.log

__all__ = [
    'LogCallback',
]

import datetime
import logging
from typing import Any, TypeVar

import torch
from torch import nn

from ...bases.configs import Config
from ...loggers import Formatter
from ...patches.torch import get_rank
from ...utils import Store, collect_env_, get_timestamp
from ..memo import Memo
from ..registries import CallbackRegistry
from ..utils import BaseETA, ETARegistry
from .base import BaseCallback
from .interval import IntervalMixin

T = TypeVar('T', bound=nn.Module)


[docs] @CallbackRegistry.register_() class LogCallback(IntervalMixin[T], BaseCallback[T]):
[docs] def __init__( self, *args, collect_env: Config | None = None, with_file_handler: bool = False, eta: Config | None = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) self._collect_env = collect_env self._with_file_handler = with_file_handler self._eta_config = eta
[docs] def bind(self, *args, **kwargs) -> None: super().bind(*args, **kwargs) if get_rank() > 0: return if self._with_file_handler: file = self.runner.work_dir / f'{get_timestamp()}.log' handler = logging.FileHandler(file) handler.setFormatter(Formatter()) self.runner.logger.addHandler(handler) if self._collect_env is not None: env = collect_env_(**self._collect_env) self.runner.logger.info(env)
[docs] def before_run(self, memo: Memo) -> None: super().before_run(memo) self._eta: BaseETA | None = ( None if self._eta_config is None else ETARegistry.build( self._eta_config, start=self.runner.iter_ - 1, end=self.runner.iters, ) )
[docs] def before_run_iter(self, batch: Any, memo: Memo) -> None: super().before_run_iter(batch, memo) if get_rank() == 0 and self._should_run_iter(): memo['log'] = dict()
[docs] def after_run_iter(self, batch: Any, memo: Memo) -> None: super().after_run_iter(batch, memo) if 'log' not in memo: return prefix = f"Iter [{self.runner.iter_}/{self.runner.iters}] " if self._eta is not None: eta = self._eta(self.runner.iter_) eta = round(eta) prefix += f"ETA {str(datetime.timedelta(seconds=eta))} " if Store.cuda: # pylint: disable=using-constant-test max_memory_allocated = max( torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count()) ) torch.cuda.reset_peak_memory_stats() prefix += f"Memory {max_memory_allocated / 1024 ** 2:.2f}M " log: dict[str, Any] = memo.pop('log') message = ' '.join(f'{k}={v}' for k, v in log.items() if v is not None) self.runner.logger.info(prefix + message)
[docs] def before_run_epoch(self, epoch_memo: Memo, memo: Memo) -> None: super().before_run_epoch(epoch_memo, memo) runner = self.epoch_based_trainer if get_rank() == 0: runner.logger.info( "Epoch [%d/%d]", runner.epoch + 1, runner.epochs, )