Source code for todd.runners.strategies.base

__all__ = [
    'BaseStrategy',
]

from typing import Any, Mapping, TypeVar, cast

import torch
from torch import nn

from ...bases.configs import Config
from ...patches.torch import load_state_dict, load_state_dict_
from ...registries import OptimizerRegistry
from ...utils import StateDictMixin
from ..registries import StrategyRegistry
from ..utils import RunnerHolderMixin

T = TypeVar('T', bound=nn.Module)


[docs] @StrategyRegistry.register_() class BaseStrategy(RunnerHolderMixin[T], StateDictMixin):
[docs] def __init__( self, *args, setup: Config | None = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) if setup is None: setup = Config() self.setup(setup)
[docs] def setup(self, config: Config) -> None: pass
[docs] def compile_model(self, model: nn.Module, config: Config) -> nn.Module: model.forward = torch.compile(model.forward, **config) return model
[docs] def map_model(self, model: nn.Module, config: Config) -> nn.Module: return model
[docs] def wrap_model(self, model: nn.Module, config: Config) -> T: return cast(T, model)
[docs] def build_optimizer( self, config: Config, model: nn.Module, ) -> torch.optim.Optimizer: return OptimizerRegistry.build(config, model=model)
@property def module(self) -> nn.Module: return self.runner.model
[docs] def model_state_dict(self, *args, **kwargs) -> dict[str, Any]: return self.module.state_dict(*args, **kwargs)
[docs] def load_model_state_dict( self, state_dict: Mapping[str, Any], *args, **kwargs, ) -> None: load_state_dict( self.module, state_dict, *args, logger=self.runner.logger, **kwargs, )
[docs] def load_model_from( self, f: torch.serialization.FileLike | list[torch.serialization.FileLike], *args, **kwargs, ) -> None: model_state_dict = load_state_dict_(f, logger=self.runner.logger) self.load_model_state_dict(model_state_dict, *args, **kwargs)
[docs] def optim_state_dict(self, *args, **kwargs) -> dict[str, Any]: return self.trainer.optimizer.state_dict()
[docs] def load_optim_state_dict( self, state_dict: Mapping[str, Any], *args, **kwargs, ) -> None: state_dict = dict(state_dict) self.trainer.optimizer.load_state_dict(state_dict)