Source code for todd.runners.callbacks.checkpoint

# Public API of this module: only the checkpoint callback is exported.
__all__ = ["CheckpointCallback"]

import pathlib
from typing import TypeVar

import torch
from torch import nn

from ...bases.configs import Config
from ...patches.torch import get_rank
from ..memo import Memo
from ..registries import CallbackRegistry
from .base import BaseCallback
from .interval import IntervalMixin

# Type parameter for the model wrapped by the runner; restricted to nn.Module.
T = TypeVar("T", bound=nn.Module)


@CallbackRegistry.register_()
class CheckpointCallback(IntervalMixin[T], BaseCallback[T]):
    """Periodically save the runner's state and restore it on bind.

    Checkpoints are written to ``<runner.work_dir>/checkpoints/<name>/``,
    one ``<key>.pth`` file per entry of ``runner.state_dict()``. After each
    save, ``checkpoints/latest`` is replaced by a symlink to the newest
    checkpoint directory, which is what ``runner.auto_resume`` looks for.
    """

    def __init__(
        self,
        *args,
        state_dict: Config | None = None,
        load_state_dict: Config | None = None,
        **kwargs,
    ) -> None:
        """Initialize the callback.

        Args:
            state_dict: keyword arguments forwarded to ``runner.state_dict``
                when saving. Defaults to an empty ``Config``.
            load_state_dict: keyword arguments forwarded to
                ``runner.load_state_dict`` when restoring. Defaults to an
                empty ``Config``.
        """
        super().__init__(*args, **kwargs)
        self._state_dict = Config() if state_dict is None else state_dict
        self._load_state_dict = (
            Config() if load_state_dict is None else load_state_dict
        )

    def bind(self, *args, **kwargs) -> None:
        """Bind to the runner, create the checkpoint dir, and maybe restore.

        Every rank loads the checkpoint; only rank 0 logs it.

        Raises:
            FileNotFoundError: if ``runner.load_from`` is set (and
                auto-resume does not apply) but the path does not exist.
        """
        super().bind(*args, **kwargs)
        self.work_dir.mkdir(parents=True, exist_ok=True)

        load_from = self._resolve_load_from()
        if load_from is None:
            return

        if get_rank() == 0:
            self.runner.logger.info("Loading from %s", load_from)
        # NOTE(review): torch.load unpickles the files; only point
        # load_from / auto_resume at trusted checkpoints.
        state_dict = {
            f.stem: torch.load(f, map_location='cpu')
            for f in load_from.glob('*.pth')
        }
        self.runner.load_state_dict(state_dict, **self._load_state_dict)

    def _resolve_load_from(self) -> pathlib.Path | None:
        """Return the checkpoint directory to restore from, or ``None``.

        Auto-resume from ``latest`` takes priority over ``runner.load_from``.
        """
        if self.runner.auto_resume and self.latest_checkpoint_dir.exists():
            return self.latest_checkpoint_dir
        if self.runner.load_from is None:
            return None
        load_from = pathlib.Path(self.runner.load_from)
        # Explicit check instead of `assert`, which is stripped under -O and
        # would otherwise let a typo silently load an empty state dict.
        if not load_from.exists():
            raise FileNotFoundError(
                f"Checkpoint to load from does not exist: {load_from}",
            )
        return load_from

    @property
    def work_dir(self) -> pathlib.Path:
        """Directory holding all checkpoints of this run."""
        return self.runner.work_dir / 'checkpoints'

    @property
    def latest_checkpoint_dir(self) -> pathlib.Path:
        """Path of the ``latest`` symlink to the most recent checkpoint."""
        return self._checkpoint_dir('latest')

    def _checkpoint_dir(self, name: str) -> pathlib.Path:
        """Return the directory for the checkpoint called ``name``."""
        return self.work_dir / name

    def _save(self, name: str) -> None:
        """Save the runner's state dict under ``name`` and update ``latest``.

        Must be called on every rank; only rank 0 writes files.
        """
        # For FSDP, all ranks must call state_dict, so it is called before
        # the non-zero ranks return.
        state_dict = self.runner.state_dict(**self._state_dict)
        if get_rank() != 0:
            return

        checkpoint_dir = self._checkpoint_dir(name)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.runner.logger.info("Saving state dict to %s", checkpoint_dir)
        for k, v in state_dict.items():
            torch.save(v, checkpoint_dir / f'{k}.pth')

        # Re-point `latest` only after every file has been written.
        latest = self.latest_checkpoint_dir
        latest.unlink(missing_ok=True)
        latest.symlink_to(checkpoint_dir.absolute(), target_is_directory=True)

    def after_run_iter(self, batch, memo: Memo) -> None:
        """Save an ``iter_<n>`` checkpoint when the iteration interval hits."""
        super().after_run_iter(batch, memo)
        if self._should_run_iter():
            self._save(f'iter_{self.runner.iter_}')

    def after_run_epoch(self, epoch_memo: Memo, memo: Memo) -> None:
        """Save an ``epoch_<n>`` checkpoint when the epoch interval hits."""
        super().after_run_epoch(epoch_memo, memo)
        if self._should_run_epoch():
            self._save(f'epoch_{self.epoch_based_trainer.epoch}')