Source code for todd.runners.callbacks.lr

__all__ = [
    'LRScheduleCallback',
    'LRScaleCallback',
]

from typing import Any, Mapping, TypeVar

import torch
from torch import nn

from ...bases.configs import Config
from ...patches.torch import get_rank, get_world_size
from ...registries import LRSchedulerRegistry
from ..memo import Memo
from ..registries import CallbackRegistry
from .base import BaseCallback
from .interval import IntervalMixin

T = TypeVar('T', bound=nn.Module)


@CallbackRegistry.register_()
class LRScheduleCallback(IntervalMixin[T], BaseCallback[T]):
    """Callback that owns a learning-rate scheduler and steps it.

    The scheduler is built from a config once the callback is bound, is
    stepped after iterations and/or epochs depending on the interval
    checks, and is included in the callback's state dict.
    """

    def __init__(
        self,
        *args,
        lr_scheduler: Config,
        interval: int = 1,
        **kwargs,
    ) -> None:
        """Keep the scheduler config for later construction in ``bind``.

        Args:
            *args: forwarded positionally to the parent initializer.
            lr_scheduler: config handed to ``LRSchedulerRegistry.build``.
            interval: forwarded to the parent initializer; defaults to 1.
            **kwargs: forwarded as keywords to the parent initializer.
        """
        super().__init__(*args, interval=interval, **kwargs)
        self._lr_scheduler_config = lr_scheduler

    def bind(self, *args, **kwargs) -> None:
        """Bind the callback, then build the scheduler on the optimizer.

        The scheduler is created here rather than in ``__init__`` because
        it needs ``self.trainer.optimizer``.
        """
        super().bind(*args, **kwargs)
        scheduler = LRSchedulerRegistry.build(
            self._lr_scheduler_config,
            optimizer=self.trainer.optimizer,
        )
        self._lr_scheduler: torch.optim.lr_scheduler.LRScheduler = scheduler

    def after_run_iter(self, batch: Any, memo: Memo) -> None:
        """Record the current learning rates and maybe step the scheduler.

        Args:
            batch: the batch that was just processed (passed to the parent).
            memo: per-iteration memo; if it has a ``'log'`` entry, the
                formatted learning rates are stored under ``'lr'`` there.
        """
        super().after_run_iter(batch, memo)
        # Log before stepping so the reported values are the ones that were
        # in effect before this call's step.
        if 'log' in memo:
            last_lrs = self._lr_scheduler.get_last_lr()
            memo['log']['lr'] = [format(lr, '.3e') for lr in last_lrs]
        if self._should_run_iter():
            self._lr_scheduler.step()

    def after_run_epoch(self, epoch_memo: Memo, memo: Memo) -> None:
        """Step the scheduler when the per-epoch interval check passes."""
        super().after_run_epoch(epoch_memo, memo)
        if self._should_run_epoch():
            self._lr_scheduler.step()

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        *args,
        **kwargs,
    ) -> None:
        """Restore the parent state, then the scheduler state.

        Args:
            state_dict: mapping that must contain an ``'lr_scheduler'`` key
                holding the scheduler's own state dict.
            *args: forwarded to the parent ``load_state_dict``.
            **kwargs: forwarded to the parent ``load_state_dict``.
        """
        super().load_state_dict(state_dict, *args, **kwargs)
        scheduler_state = state_dict['lr_scheduler']
        self._lr_scheduler.load_state_dict(scheduler_state)

    def state_dict(self, *args, **kwargs) -> dict[str, Any]:
        """Return the parent state extended with the scheduler state.

        Returns:
            The parent's state dict with the scheduler's state dict added
            under the ``'lr_scheduler'`` key.
        """
        state = super().state_dict(*args, **kwargs)
        state['lr_scheduler'] = self._lr_scheduler.state_dict()
        return state
@CallbackRegistry.register_()
class LRScaleCallback(BaseCallback[T]):
    """Callback that rescales optimizer learning rates by batch size.

    On ``bind`` every learning rate held by the trainer's optimizer is
    multiplied by ``total_batch_size / base_batch_size``.
    """

    def __init__(self, *args, lr_scaler: Config, **kwargs) -> None:
        """Keep the scaler config for use in ``bind``.

        Args:
            *args: forwarded positionally to the parent initializer.
            lr_scaler: config whose ``base_batch_size`` attribute is the
                reference batch size the learning rates were chosen for.
            **kwargs: forwarded as keywords to the parent initializer.
        """
        super().__init__(*args, **kwargs)
        self._lr_scaler_config = lr_scaler

    def bind(self, *args, **kwargs) -> None:
        """Bind the callback and scale the optimizer's learning rates.

        The total batch size is the dataloader's batch size (1 when it is
        ``None``) times the world size. Rank 0 logs the resulting factor.
        """
        super().bind(*args, **kwargs)

        owner = self.trainer
        loader = owner.dataloader
        optimizer = owner.optimizer

        per_process = loader.batch_size
        if per_process is None:
            per_process = 1

        # NOTE: the three names below appear verbatim in the log message via
        # the f-string ``=`` specifier, so they must not be renamed.
        batch_size = get_world_size() * per_process
        base_batch_size = self._lr_scaler_config.base_batch_size
        lr_scaler = batch_size / base_batch_size

        # Scale the optimizer-wide default as well as each group's value;
        # entries without an 'lr' key are left untouched.
        defaults = optimizer.defaults
        if 'lr' in defaults:
            defaults['lr'] *= lr_scaler
        for group in optimizer.param_groups:
            if 'lr' not in group:
                continue
            group['lr'] *= lr_scaler

        if get_rank() != 0:
            return
        self.runner.logger.info(
            f"{base_batch_size=} {batch_size=} {lr_scaler=:.3f}",
        )