__all__ = [
'ModuleList',
'ModuleDict',
'Sequential',
'training_modules',
'named_training_modules',
'trainable_parameters',
'named_trainable_parameters',
'load_state_dict',
'load_state_dict_',
]
import logging
from typing import Any, Generator, Mapping
import torch
from torch import nn
from .serialization import load
[docs]
class ModuleList(nn.ModuleList):
[docs]
def forward(self, *args, **kwargs) -> list[nn.Module]:
return [m(*args, **kwargs) for m in self]
[docs]
class ModuleDict(nn.ModuleDict):
[docs]
def forward(self, *args, **kwargs) -> dict[str, nn.Module]:
return {k: m(*args, **kwargs) for k, m in self.items()}
[docs]
class Sequential(nn.Sequential):
[docs]
def __init__(self, *args, unpack_args: bool = False, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._unpack_args = unpack_args
[docs]
def forward(self, *args, **kwargs) -> tuple[Any, ...]:
if not self._unpack_args:
args, = args
for m in self:
args = (
m(*args, **kwargs) if self._unpack_args else m(args, **kwargs)
)
return args
[docs]
def training_modules(
module: nn.Module,
*args,
**kwargs,
) -> Generator[nn.Module, None, None]:
for m in module.modules(*args, **kwargs):
if m.training:
yield m
[docs]
def named_training_modules(
module: nn.Module,
*args,
**kwargs,
) -> Generator[tuple[str, nn.Module], None, None]:
for name, m in module.named_modules(*args, **kwargs):
if m.training:
yield name, m
[docs]
def trainable_parameters(
module: nn.Module,
*args,
**kwargs,
) -> Generator[nn.Parameter, None, None]:
for parameter in module.parameters(*args, **kwargs):
if parameter.requires_grad:
yield parameter
[docs]
def named_trainable_parameters(
module: nn.Module,
*args,
**kwargs,
) -> Generator[tuple[str, nn.Parameter], None, None]:
for name, parameter in module.named_parameters(*args, **kwargs):
if parameter.requires_grad:
yield name, parameter
[docs]
def load_state_dict(
module: nn.Module,
state_dict: Mapping[str, Any],
*args,
logger: logging.Logger | None = None,
**kwargs,
) -> None:
if logger is None:
from ...loggers import master_logger as logger
assert logger is not None
incompatible_keys = module.load_state_dict(state_dict, *args, **kwargs)
logger.info(incompatible_keys)
[docs]
def load_state_dict_(
f: torch.serialization.FileLike | list[torch.serialization.FileLike],
*args,
logger: logging.Logger | None = None,
**kwargs,
) -> dict[str, Any]:
f_list = f if isinstance(f, list) else [f]
if logger is None:
from ...loggers import master_logger as logger
assert logger is not None
state_dict: dict[str, Any] = dict()
for f_ in f_list:
logger.info("Loading model from %s", f_)
state_dict.update(load(f_, 'cpu', *args, **kwargs))
return state_dict