Source code for todd.runners.utils.runner_holder

__all__ = [
    'RunnerHolderMixin',
]

from typing import TypeVar

from torch import nn

from ...utils import HolderMixin
from ..base import BaseRunner
from ..epoch_based_trainer import EpochBasedTrainer
from ..iter_based_trainer import IterBasedTrainer
from ..trainer import Trainer
from ..validator import Validator

T = TypeVar('T', bound=nn.Module)


[docs] class RunnerHolderMixin(HolderMixin[BaseRunner[T]]):
[docs] def __init__( self, *args, runner: BaseRunner[T] | None = None, **kwargs, ) -> None: super().__init__(*args, instance=runner, **kwargs)
@property def runner(self) -> BaseRunner[T]: return self._instance @property def trainer(self) -> Trainer[T]: assert isinstance(self._instance, Trainer) return self._instance @property def validator(self) -> Validator[T]: assert isinstance(self._instance, Validator) return self._instance @property def iter_based_trainer(self) -> IterBasedTrainer[T]: assert isinstance(self._instance, IterBasedTrainer) return self._instance @property def epoch_based_trainer(self) -> EpochBasedTrainer[T]: assert isinstance(self._instance, EpochBasedTrainer) return self._instance