# Source code for todd.runners.trainer

__all__ = [
    'Trainer',
]

from abc import ABC
from typing import TYPE_CHECKING, Any, Mapping, TypeVar

import torch
from torch import nn

from ..bases.configs import Config
from ..bases.registries import Item, RegistryMeta
from ..registries import RunnerRegistry
from .base import BaseRunner
from .memo import Memo

if TYPE_CHECKING:
    from .strategies import BaseStrategy

T = TypeVar('T', bound=nn.Module)


@RunnerRegistry.register_()
class Trainer(BaseRunner[T], ABC):
    """Runner that holds an optimizer in addition to the base runner state.

    The class is registered with ``RunnerRegistry``. It relies on the
    attributes ``_dataloader``, ``_iter``, ``_model`` and ``_strategy``,
    none of which are assigned in this class (presumably provided by
    ``BaseRunner`` -- confirm against the base class).
    """

    def __init__(
        self,
        *args,
        optimizer: torch.optim.Optimizer,
        **kwargs,
    ) -> None:
        """Keep a reference to ``optimizer`` and initialize the parent.

        Args:
            *args: Positional arguments forwarded to the parent initializer.
            optimizer: Keyword-only optimizer, later exposed through the
                :attr:`optimizer` property.
            **kwargs: Keyword arguments forwarded to the parent initializer.
        """
        # Assigned first, so the attribute already exists while the parent
        # initializer runs.
        self._optimizer = optimizer
        super().__init__(*args, **kwargs)

    @classmethod
    def optimizer_build_pre_hook(
        cls,
        config: Config,
        registry: RegistryMeta,
        item: Item,
    ) -> Config:
        """Replace ``config.optimizer`` with the strategy-built optimizer.

        Args:
            config: Build config; its ``strategy``, ``model`` and
                ``optimizer`` entries are read, and ``optimizer`` is
                overwritten in place.
            registry: Not used by this hook.
            item: Not used by this hook.

        Returns:
            The same ``config`` object, with ``optimizer`` replaced by the
            result of ``strategy.build_optimizer``.
        """
        strategy: 'BaseStrategy[T]' = config.strategy
        module: nn.Module = config.model
        built_optimizer = strategy.build_optimizer(config.optimizer, module)
        config.optimizer = built_optimizer
        return config

    @classmethod
    def build_pre_hook(
        cls,
        config: Config,
        registry: RegistryMeta,
        item: Item,
    ) -> Config:
        """Run the parent pre-hook, then the optimizer pre-hook.

        Args:
            config: Build config handed to both hooks in turn.
            registry: Forwarded unchanged to both hooks.
            item: Forwarded unchanged to both hooks.

        Returns:
            The config returned by :meth:`optimizer_build_pre_hook`.
        """
        # The optimizer hook reads ``config.strategy`` and ``config.model``,
        # so it runs on the config produced by the parent hook.
        parent_config = super().build_pre_hook(config, registry, item)
        return cls.optimizer_build_pre_hook(parent_config, registry, item)

    @property
    def iters_per_epoch(self) -> int:
        """Length of the dataloader, i.e. ``len(self._dataloader)``."""
        return len(self._dataloader)

    @property
    def inner_iter(self) -> int:
        """Remainder of ``self._iter`` divided by :attr:`iters_per_epoch`."""
        per_epoch = self.iters_per_epoch
        return self._iter % per_epoch

    @property
    def epoch(self) -> int:
        """Floor quotient of ``self._iter`` by :attr:`iters_per_epoch`."""
        per_epoch = self.iters_per_epoch
        return self._iter // per_epoch

    @property
    def optimizer(self) -> torch.optim.Optimizer:
        """The optimizer passed to the constructor."""
        return self._optimizer

    def _setup(self) -> Memo:
        """Call ``train()`` on the model, then defer to the parent setup.

        Returns:
            Whatever the parent ``_setup`` returns.
        """
        self._model.train()
        memo = super()._setup()
        return memo

    def state_dict(self, *args, **kwargs) -> dict[str, Any]:
        """Return the parent state dict extended with an ``'optim'`` entry.

        Args:
            *args: Forwarded both to the parent ``state_dict`` and to
                ``self._strategy.optim_state_dict``.
            **kwargs: Forwarded to the same two calls.

        Returns:
            The parent's state dict with key ``'optim'`` set to the value
            returned by the strategy.
        """
        state = super().state_dict(*args, **kwargs)
        optim_state = self._strategy.optim_state_dict(*args, **kwargs)
        state['optim'] = optim_state
        return state

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        *args,
        **kwargs,
    ) -> None:
        """Load the parent state, then the optimizer state via the strategy.

        Args:
            state_dict: Mapping handed to the parent loader; its ``'optim'``
                entry is handed to ``self._strategy.load_optim_state_dict``.
            *args: Forwarded to both load calls.
            **kwargs: Forwarded to both load calls.

        Raises:
            KeyError: If ``state_dict`` has no ``'optim'`` key (raised after
                the parent state has already been loaded).
        """
        super().load_state_dict(state_dict, *args, **kwargs)
        optim_state = state_dict['optim']
        self._strategy.load_optim_state_dict(optim_state, *args, **kwargs)