Source code for todd.utils.stores

# Names exported by a star-import of this module.
__all__ = [
    'StoreMeta',
    'Store',
]

import os
from typing import Any

from ..loggers import logger
from ..patches.py_ import NonInstantiableMeta, classproperty
from ..patches.torch import get_device


class StoreMeta(NonInstantiableMeta):
    """Stores for global variables.

    Stores provide an interface to access global variables:

        >>> class CustomStore(metaclass=StoreMeta):
        ...     VARIABLE: int
        >>> CustomStore.VARIABLE
        0
        >>> CustomStore.VARIABLE = 1
        >>> CustomStore.VARIABLE
        1

    Variables can have explicit default values:

        >>> class DefaultStore(metaclass=StoreMeta):
        ...     DEFAULT: float = 0.625
        >>> DefaultStore.DEFAULT
        0.625

    Non-empty environment variables are read-only.
    For string variables, their values are read directly from the
    environment.
    Other environment variables are evaluated and should be of the
    corresponding type.
    Default values are ignored.

        >>> os.environ['ENV_INT'] = '2'
        >>> os.environ['ENV_STR'] = 'hello world!'
        >>> os.environ['ENV_DICT'] = 'dict(a=1)'
        >>> class EnvStore(metaclass=StoreMeta):
        ...     ENV_INT: int = 1
        ...     ENV_STR: str
        ...     ENV_DICT: dict
        >>> EnvStore.ENV_INT
        2
        >>> EnvStore.ENV_STR
        'hello world!'
        >>> EnvStore.ENV_DICT
        {'a': 1}

    Assignments to those variables will not trigger exceptions, but will not
    take effect:

        >>> EnvStore.ENV_INT = 3
        >>> EnvStore.ENV_INT
        2

    Empty environment variables are treated as unset, so the default value
    (and normal assignment) still applies:

        >>> os.environ['ENV_EMPTY'] = ''
        >>> class EmptyStore(metaclass=StoreMeta):
        ...     ENV_EMPTY: int = 4
        >>> EmptyStore.ENV_EMPTY
        4
    """

    def __init__(cls, *args, **kwargs) -> None:
        """Give every annotated variable without a value a default.

        The default is obtained by calling the annotated type with no
        arguments, e.g. ``int()`` or ``bool()``.
        """
        super().__init__(*args, **kwargs)
        for k, v in cls.__annotations__.items():
            if not hasattr(cls, k):
                setattr(cls, k, v())

    def _overridden(cls, name: str) -> bool:
        """Check if ``name`` is overridden by the environment.

        Args:
            name: the attribute name.

        Returns:
            Whether ``name`` is an annotated variable of the store and a
            non-empty environment variable of the same name exists.
        """
        # An exported-but-empty variable (e.g. ``DRY_RUN=``) counts as unset.
        # Otherwise, non-string variables would be passed to ``eval('')``,
        # which raises `SyntaxError` on every access.
        return name in cls.__annotations__ and bool(os.environ.get(name))

    def __getattribute__(cls, name: str) -> Any:
        """Read an attribute, preferring the environment when overridden.

        Args:
            name: the attribute name.

        Returns:
            The environment value if ``name`` is overridden, which is used
            directly for `str` variables and evaluated otherwise. If not
            overridden, the result of the regular attribute lookup.
        """
        # ``__annotations__`` and ``_overridden`` are required to perform the
        # override check itself. They must short-circuit to the regular
        # lookup, or this method would recurse infinitely.
        if (
            name in ['__annotations__', '_overridden']
            # pylint: disable=no-value-for-parameter
            or not cls._overridden(name)
        ):
            return super().__getattribute__(name)
        type_ = cls.__annotations__[name]
        variable: Any = os.environ[name]
        if type_ is not str:
            # NOTE(review): the environment is evaluated as Python code, so
            # it must be trusted.
            variable = eval(variable)  # nosec B307
        assert isinstance(variable, type_), (
            f"Environment variable {name}={variable!r} is not of type "
            f"{type_}."
        )
        return variable

    def __setattr__(cls, name: str, value) -> None:
        """Set an attribute, unless it is overridden by the environment.

        Assignments to overridden variables are dropped with a debug log
        instead of raising.

        Args:
            name: the attribute name.
            value: the value to assign.
        """
        if not cls._overridden(name):  # pylint: disable=no-value-for-parameter
            super().__setattr__(name, value)
            return
        logger.debug("Cannot set %s to %s.", name, value)

    def __repr__(cls) -> str:
        """Return the store name followed by its ``name=value`` variables."""
        variables = ' '.join(
            f'{k}={getattr(cls, k)}' for k in cls.__annotations__
        )
        return f"<{cls.__name__} {variables}>"
class Store(metaclass=StoreMeta):
    """Global variables shared across the project.

    The variables are managed by `StoreMeta`, so they can be overridden by
    environment variables of the same names.
    """

    # Defaults to the result of ``get_device()`` at class creation.
    DEVICE: str = get_device()
    # The two flags below have no explicit default, so `StoreMeta` falls
    # back to ``bool()``, i.e. ``False``.
    DRY_RUN: bool
    TRAIN_WITH_VAL_DATASET: bool

    @classmethod
    def _device(cls, name: str) -> bool:
        """Check if ``DEVICE`` equals ``name``.

        Args:
            name: the device name to compare against.

        Returns:
            Whether the current ``DEVICE`` is ``name``.
        """
        current = cls.DEVICE
        return current == name

    @classproperty
    def cpu(self) -> bool:
        """Whether ``DEVICE`` is ``'cpu'``."""
        is_cpu = self._device('cpu')
        return is_cpu

    @classproperty
    def cuda(self) -> bool:
        """Whether ``DEVICE`` is ``'cuda'``."""
        is_cuda = self._device('cuda')
        return is_cuda

    @classproperty
    def mps(self) -> bool:
        """Whether ``DEVICE`` is ``'mps'``."""
        is_mps = self._device('mps')
        return is_mps