# Source code for todd.models.shadows.base

# Names exported by a star-import of this module.
__all__ = ['BaseShadow']

from abc import ABC, abstractmethod
from functools import partial
from typing import Any

import torch
from torch import nn

from ...utils import NestedTensorCollectionUtils, StateDict
from ..norms import BATCHNORMS


class BaseShadow(nn.Module, ABC):
    """Keep a separate copy ("shadow") of a module's state dict.

    Subclasses implement :meth:`_forward`, which combines one tensor of the
    current shadow with the matching tensor of a module's current state;
    :meth:`forward` applies it over the whole state dict and stores the
    result as the new shadow.
    """

    def __init__(
        self,
        *args,
        module: nn.Module,
        device: Any = None,
        **kwargs,
    ) -> None:
        """Snapshot ``module``'s state dict as the initial shadow.

        Args:
            module: the module whose state is shadowed. It must not contain
                any layer that is an instance of ``BATCHNORMS``.
            device: where shadow tensors are placed. ``None`` leaves each
                tensor on the device it already occupies.

        Raises:
            AssertionError: if ``module`` contains a batch-norm layer
                (only checked when assertions are enabled).
        """
        # BN layers are not supported
        assert not any(
            isinstance(m, BATCHNORMS) for m in module.modules()
        ), 'BN layers are not supported'
        super().__init__(*args, **kwargs)
        self._device = device
        # ``state_dict()`` returns detached tensors that share storage with
        # the module, and ``Tensor.to`` is a no-op when the device already
        # matches. Copy explicitly so that in-place updates of the module
        # (e.g. optimizer steps) cannot silently rewrite the shadow.
        self._shadow = self._to_device(module.state_dict(), copy=True)

    @property
    def shadow(self) -> StateDict:
        """The current shadow state dict."""
        return self._shadow

    @shadow.setter
    def shadow(self, value: StateDict) -> None:
        # Incoming tensors are moved to the configured device (not copied
        # when no move is needed).
        self._shadow = self._to_device(value)

    def _to_device(
        self,
        state_dict: StateDict,
        copy: bool = False,
    ) -> StateDict:
        """Move every tensor of ``state_dict`` to the configured device.

        Args:
            state_dict: the tensor collection to move.
            copy: when ``True``, always produce fresh tensors, even if no
                device is configured or a tensor is already on it.

        Returns:
            The moved collection. When no device is configured and ``copy``
            is ``False``, ``state_dict`` itself is returned unchanged.
        """
        utils = NestedTensorCollectionUtils()
        if self._device is None:
            if not copy:
                return state_dict
            return utils.map(torch.Tensor.clone, state_dict)
        f = partial(torch.Tensor.to, device=self._device, copy=copy)
        return utils.map(f, state_dict)

    def _state_dict_to_device(self, module: nn.Module) -> StateDict:
        """Return ``module``'s state dict moved to the configured device."""
        return self._to_device(module.state_dict())

    @abstractmethod
    def _forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Combine a shadow tensor with the matching module tensor.

        Called by :meth:`forward` with ``x`` taken from the current shadow
        and ``y`` from the module's current state (paired by
        ``NestedTensorCollectionUtils.map`` -- presumably by key; confirm).
        The returned tensor becomes the new shadow entry.

        NOTE(review): ``y`` may share storage with the module's parameters
        (e.g. when no device is configured), so implementations should
        return a new tensor rather than ``y`` itself or an in-place view.
        """
        pass

    def forward(self, module: nn.Module) -> None:
        """Update the shadow from ``module``'s current state.

        Every shadow tensor ``x`` is replaced by ``self._forward(x, y)``,
        where ``y`` is the corresponding tensor of ``module.state_dict()``
        after moving it to the configured device.
        """
        utils = NestedTensorCollectionUtils()
        self._shadow = utils.map(
            self._forward,  # type: ignore[arg-type]
            self._shadow,
            self._state_dict_to_device(module),
        )