Source code for todd.utils.state_dicts

__all__ = [
    'StateDict',
    'Keys',
    'StateDictMixin',
    'transfer_state_dict',
    'transfer_state_dicts',
    'StateDictConverter',
]

import functools
import re
from collections import defaultdict
from itertools import starmap
from typing import Any, Mapping, NamedTuple, TypeVar

import torch
from torch import nn

from ..loggers import master_logger
from ..patches.py_ import get_
from ..patches.torch import load_state_dict_
from .misc import set_temp

T = TypeVar('T')

StateDict = dict[str, torch.Tensor]


[docs] class Keys(NamedTuple): missing: list[str] unexpected: list[str]
[docs] class StateDictMixin:
[docs] def state_dict(self, *args, **kwargs) -> dict[str, Any]: return dict()
[docs] def load_state_dict( self, state_dict: Mapping[str, Any], *args, **kwargs, ) -> Keys | None: pass
[docs] def transfer_state_dict( target: nn.Module | StateDictMixin, source: nn.Module | StateDictMixin, ) -> None: state_dict = source.state_dict() keys = target.load_state_dict(state_dict, strict=False) if keys is not None: missing, unexpected = keys master_logger.info( "\nMissing keys: %s\nUnexpected keys: %s", ', '.join(missing), ', '.join(unexpected), )
[docs] def transfer_state_dicts(models: Any, prefixes: Mapping[str, str]) -> None: for target_prefix, source_prefix in prefixes.items(): target = get_(models, target_prefix) source = get_(models, source_prefix) transfer_state_dict(target, source)
class UnknownKeyError(ValueError): def __init__(self, key: str, *args) -> None: super().__init__(f"'{key}'", *args)
[docs] class StateDictConverter:
[docs] def __init__( self, *args, module: nn.Module | None = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) self._module = module self._key_mapping: dict[str, str | None] = dict() self._child_converters: dict[ str, tuple[str | None, type[StateDictConverter]], ] = dict() self._regex_converters: dict[re.Pattern[str], str | None] = dict()
@property def module(self) -> nn.Module: assert self._module is not None return self._module def _register_key_mapping(self, key: str, new_key: str | None) -> None: self._key_mapping[key] = new_key def _register_child_converter( self, child_name: str, new_child_name: str | None, child_converter_type: type['StateDictConverter'], ) -> None: assert '.' not in child_name self._child_converters[child_name] = ( new_child_name, child_converter_type, ) def _register_regex_converter( self, key_pattern: str | re.Pattern[str], new_key: str | None, ) -> None: if isinstance(key_pattern, str): key_pattern = re.compile(key_pattern) self._regex_converters[key_pattern] = new_key
[docs] def load(self, *args, **kwargs) -> StateDict: return load_state_dict_(*args, **kwargs)
def _pre_convert(self, state_dict: StateDict) -> StateDict: return state_dict def _map_keys(self, state_dict: StateDict) -> StateDict: new_state_dict: StateDict = dict() for key, new_key in self._key_mapping.items(): if key not in state_dict: continue value = state_dict.pop(key) if new_key is not None: new_state_dict[new_key] = value return new_state_dict def _convert_child(self, name: str, state_dict: StateDict) -> StateDict: new_name, converter_type = self._child_converters[name] prefix = '' if new_name is None else f'{new_name}.' if self._module is None: module = None elif new_name is None: module = self._module else: module = getattr(self._module, new_name) converter = converter_type(module=module) new_state_dict = converter.convert(state_dict) return {prefix + k: v for k, v in new_state_dict.items()} def _convert_children(self, state_dict: StateDict) -> StateDict: new_state_dict: StateDict = dict() if not self._child_converters: return new_state_dict children: defaultdict[str, StateDict] = defaultdict(dict) for key in list(state_dict): if '.' not in key: continue child_name, child_key = key.split('.', 1) if child_name not in self._child_converters: continue value = state_dict.pop(key) children[child_name][child_key] = value for child in starmap(self._convert_child, children.items()): new_state_dict.update(child) return new_state_dict def _convert_regex(self, state_dict: StateDict) -> StateDict: new_state_dict: StateDict = dict() if not self._regex_converters: return new_state_dict for key in list(state_dict): for pattern, new_key in self._regex_converters.items(): match_ = pattern.fullmatch(key) if match_ is None: continue value = state_dict.pop(key) if new_key is not None: new_key = match_.expand(new_key) new_state_dict[new_key] = value break return new_state_dict def _convert(self, key: str) -> str | None: raise UnknownKeyError(key) def _post_convert(self, state_dict: StateDict) -> StateDict: return state_dict
[docs] def convert(self, state_dict: StateDict) -> StateDict: # noqa: C901 state_dict = self._pre_convert(state_dict) # move items from state_dict to new_state_dict new_state_dict = self._map_keys(state_dict) new_state_dict |= self._convert_children(state_dict) new_state_dict |= self._convert_regex(state_dict) for key, value in state_dict.items(): new_key = self._convert(key) if new_key is not None: new_state_dict[new_key] = value new_state_dict = self._post_convert(new_state_dict) return new_state_dict
def parallel_conversion(func): @functools.wraps(func) def wrapper(self: StateDictConverter, state_dict: StateDict) -> StateDict: state_dicts: defaultdict[str, StateDict] = defaultdict(dict) for key, value in state_dict.items(): prefix, key = key.split('.', 1) state_dicts[prefix][key] = value for prefix in state_dicts: state_dicts[prefix] = func(self, state_dicts[prefix], prefix) state_dict = { f'{prefix}.{k}': v for prefix, state_dict in state_dicts.items() for k, v in state_dict.items() } return state_dict return wrapper class SequentialStateDictConverterMixin(StateDictConverter): @parallel_conversion def convert(self, state_dict: StateDict, prefix: str) -> StateDict: # noqa: E501 pylint: disable=arguments-differ module = self._module assert isinstance(module, nn.Sequential) with set_temp(self, '._module', module[int(prefix)]): return super().convert(state_dict)