# Source code for todd.utils.ema

# Public names exported by ``from todd.utils.ema import *``.
__all__ = ['ema', 'EMA']

from typing import Any


def ema(x: Any, y: Any, decay: Any) -> Any:
    """Blend a running value with a new one.

    Computes ``x * decay + y * (1 - decay)``, so ``decay`` is the weight
    kept on ``x`` and ``1 - decay`` is the weight given to ``y``.

    Args:
        x: the running (previous) value.
        y: the new value.
        decay: weight applied to ``x``.

    Returns:
        The weighted combination of ``x`` and ``y``.
    """
    retained = x * decay
    incoming = y * (1 - decay)
    return retained + incoming
class EMA:
    """Callable exponential moving average with a fixed decay.

    ``ema_obj(x, y)`` combines a running value ``x`` with a new value ``y``
    via :func:`ema`. ``None`` on one side is treated as "no value", and the
    other side is returned unchanged.
    """

    def __init__(self, decay: Any = 0.99) -> None:
        """Create the EMA.

        Args:
            decay: weight kept on the running value; must satisfy
                ``0 <= decay <= 1``.

        Raises:
            AssertionError: if ``decay`` is outside ``[0, 1]``.
        """
        # Validate eagerly so an out-of-range decay fails at construction
        # instead of silently producing a non-convex blend on every call.
        self.check_decay(decay)
        self._decay = decay

    @classmethod
    def check_decay(cls, decay: Any) -> None:
        """Raise ``AssertionError`` unless ``0 <= decay <= 1``."""
        # Raise explicitly rather than via ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for caller compatibility.
        if not 0 <= decay <= 1:
            raise AssertionError(f'decay must be in [0, 1], got {decay!r}')

    @property
    def decay(self) -> Any:
        """The decay factor passed at construction."""
        return self._decay

    def __call__(self, x: Any, y: Any) -> Any:
        """Return the moving average of ``x`` and ``y``.

        Args:
            x: the running value, or ``None`` if there is none yet.
            y: the new value, or ``None`` to keep ``x`` unchanged.

        Returns:
            ``y`` if ``x`` is ``None``; ``x`` if ``y`` is ``None``;
            otherwise ``ema(x, y, self.decay)``.

        Raises:
            AssertionError: if both ``x`` and ``y`` are ``None``.
        """
        if x is None:
            # Explicit raise (not ``assert``) so ``-O`` cannot turn this
            # into a silent ``return None``.
            if y is None:
                raise AssertionError('x and y cannot both be None')
            return y
        if y is None:
            return x
        return ema(x, y, self._decay)