# Names exported by ``from <this module> import *``.
__all__ = [
'Statistician',
]
import torch
# NOTE: a "[docs]" link marker from the rendered documentation page stood here; it is not code.
class Statistician:
[docs]
def __init__(self, chunk_size: int) -> None:
self._chunk_size = chunk_size
self._means: list[torch.Tensor] = []
self._variances: list[torch.Tensor] = []
self._running_chunk: list[torch.Tensor] = []
@property
def _running_chunk_size(self) -> int:
return sum(samples.shape[0] for samples in self._running_chunk)
@property
def num_samples(self) -> int:
num_chunks, = {len(self._means), len(self._variances)} # noqa: E501 pylint: disable=unbalanced-tuple-unpacking
return num_chunks * self._chunk_size + self._running_chunk_size
[docs]
def update(self, samples: torch.Tensor) -> None:
assert samples.dim() == 2
self._running_chunk.append(samples)
if self._running_chunk_size < self._chunk_size:
return
running_chunk = torch.cat(self._running_chunk)
while running_chunk.shape[0] >= self._chunk_size:
chunk = running_chunk[:self._chunk_size]
self._means.append(chunk.mean(0))
self._variances.append(chunk.var(0))
running_chunk = running_chunk[self._chunk_size:]
if running_chunk.shape[0]:
self._running_chunk = [running_chunk]
def _weighted_average(
self,
value: torch.Tensor | None,
running_value: torch.Tensor | None,
) -> torch.Tensor:
if value is None and running_value is None:
raise RuntimeError("No samples to compute")
if value is None:
assert running_value is not None
return running_value
if running_value is None:
assert value is not None
return value
w = self._running_chunk_size / self.num_samples
return (1 - w) * value + w * running_value
[docs]
def compute_mean(self) -> torch.Tensor:
mean = torch.stack(self._means).mean(0) if self._means else None
running_mean = (
torch.cat(self._running_chunk).mean(0)
if self._running_chunk else None
)
return self._weighted_average(mean, running_mean)
[docs]
def compute_variance(
self,
mean: torch.Tensor | None = None,
) -> torch.Tensor:
if mean is None:
mean = self.compute_mean()
if self._variances:
variances = [
chunk_variance + (chunk_mean - mean)**2 for chunk_mean,
chunk_variance in zip(self._means, self._variances)
]
variance = torch.stack(variances).mean(0)
else:
variance = None
if self._running_chunk:
running_chunk = torch.cat(self._running_chunk)
squared_deviation = (running_chunk - mean)**2
running_variance = squared_deviation.mean(0)
else:
running_variance = None
return self._weighted_average(variance, running_variance)