Source code for todd.registries.model

__all__ = [
    'InitWeightsMixin',
    'ModelRegistry',
]

from abc import ABC, abstractmethod
from typing import Any, Callable, cast

# include einops layers in ModelRegistry
import einops.layers.torch  # noqa: F401 pylint: disable=unused-import
from torch import nn

from ..bases.configs import Config
from ..bases.registries import Item, Registry, RegistryMeta
from ..loggers import master_logger
from ..patches.py_ import descendant_classes


class InitWeightsMixin(nn.Module, ABC):
    """Mixin for modules that provide their own weight initialization.

    ``init_weights`` is abstract, so concrete subclasses must override it.
    The body defined here is still usable through ``super()`` and acts as
    the cooperative end of the chain.
    """

    @abstractmethod
    def init_weights(self, config: Config) -> bool:
        """Initialize the weights of this module.

        Args:
            config: the weight initialization config.

        Returns:
            The result of the next ``init_weights`` in the MRO when one
            exists, otherwise ``True``. ``ModelRegistry.init_weights`` in
            this module stops descending into child modules when the
            value is falsy.
        """
        parent = super()
        if not hasattr(parent, 'init_weights'):
            # Nothing further up the MRO implements ``init_weights``.
            return True
        return parent.init_weights(config)  # type: ignore[misc]
class ModelRegistry(Registry):
    """Registry whose built ``nn.Module`` objects get weights initialized."""

    @classmethod
    def init_weights(
        cls,
        model: nn.Module,
        config: 'Config | None',
        prefix: str = '',
    ) -> None:
        """Initialize the weights of ``model`` and, optionally, its children.

        A module is visited at most once: it is flagged as initialized
        before anything else happens, including when ``config`` is ``None``.

        Args:
            model: the module to initialize.
            config: passed to ``model.init_weights`` if the module has such
                an attribute. ``None`` skips the module and its children.
            prefix: label of ``model`` used in the debug messages. Children
                are visited with ``f'{prefix}.{name}'``.
        """
        # Accessing the flag via a string keeps the attribute name literally
        # ``__initialized``; a dotted ``model.__initialized`` inside this
        # class body would be subject to private name mangling.
        flag = '__initialized'
        class_name = model.__class__.__name__
        description = f"{class_name} ({prefix}) weights"

        if getattr(model, flag, False):
            master_logger.debug("Skip re-initializing %s", description)
            return
        setattr(model, flag, True)

        if config is None:
            master_logger.debug(
                "Skip initializing %s since config is None",
                description,
            )
            return

        initializer: Callable[[Config], bool] | None = getattr(
            model,
            'init_weights',
            None,
        )
        if initializer is not None:
            master_logger.debug(
                "Initializing %s with %s",
                description,
                config,
            )
            # A falsy return value means the children must not be visited.
            if not initializer(config):
                return

        for child_name, child in model.named_children():
            cls.init_weights(child, config, f'{prefix}.{child_name}')

    @classmethod
    def _build(cls, item: Item, config: Config) -> Any:
        """Build ``item`` and initialize its weights if it is a module.

        The ``init_weights`` entry is removed from a copy of ``config``
        before building, so the caller's config is left untouched. When
        the entry is absent, an empty ``Config`` is used instead.

        Args:
            item: the registered item to build.
            config: the build config, optionally holding ``init_weights``.

        Returns:
            Whatever ``RegistryMeta._build`` returns for ``item``.
        """
        build_config = config.copy()
        weights_config = build_config.pop('init_weights', Config())
        built = RegistryMeta._build(cls, item, build_config)
        if not isinstance(built, nn.Module):
            return built
        cls.init_weights(built, weights_config)
        return built
# Register the classes reported by ``descendant_classes(nn.Module)`` under
# their flattened qualified name, e.g. ``torch_nn_modules_linear_Linear``.
# NOTE(review): presumably only subclasses imported by this point are
# found, which is why ``einops.layers.torch`` is imported above - confirm.
for c in descendant_classes(nn.Module):
    name = '_'.join(  # pylint: disable=invalid-name
        (c.__module__.replace('.', '_'), c.__name__),
    )
    if name in ModelRegistry:
        # Keep the existing registration for this name.
        continue
    ModelRegistry.register_(name)(cast(Item, c))