Shadows
- class todd.models.shadows.BaseShadow
- class todd.models.shadows.EMAShadow
Bases: BuildPreHookMixin, BaseShadow
Exponential Moving Average (EMA) Shadow.
This class maintains a shadow copy of a module's state dict and updates it with an exponential moving average of the module's current values.
- Parameters:
ema – The EMA object used for applying exponential moving average.
A copy of the state dict of the given module is stored as the initial shadow:
>>> import torch
>>> import torch.nn as nn
>>> module = nn.Module()
>>> module.register_buffer('p', torch.tensor([1., 2., 3.]))
>>> ema = EMAShadow(module=module, ema=EMA())
>>> dict(ema.shadow)
{'p': tensor([1., 2., 3.])}
Calling the shadow with the module updates the shadow toward the module's current state:
>>> module.register_buffer('p', torch.tensor([4., 5., 6.]))
>>> ema(module)
>>> dict(ema.shadow)
{'p': tensor([1.0300, 2.0300, 3.0300])}
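The values in the updated shadow are consistent with a standard EMA step using a decay of 0.99, for example 0.99 * 1 + 0.01 * 4 = 1.03. The following is a minimal sketch of that arithmetic in plain Python, not the library's implementation: the `ema_update` helper and its `decay` parameter are hypothetical names introduced for illustration, and the decay value is an assumption inferred from the numbers above.

```python
def ema_update(shadow, current, decay=0.99):
    """Return a new shadow blended toward the current values.

    Hypothetical helper: `decay=0.99` is assumed because it reproduces
    the numbers shown in the example, not because it is a documented default.
    """
    return {
        name: [
            decay * s + (1 - decay) * c
            for s, c in zip(shadow[name], current[name])
        ]
        for name in shadow
    }


# Initial shadow: a copy of the module's state.
shadow = {'p': [1.0, 2.0, 3.0]}
# The module's state after it has changed.
current = {'p': [4.0, 5.0, 6.0]}

# One EMA step moves each shadow entry 1% of the way toward the current value.
shadow = ema_update(shadow, current)
print({name: [round(v, 4) for v in values] for name, values in shadow.items()})
```

Repeated calls would keep pulling the shadow toward the module's state, with older states weighted down geometrically by the decay factor.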