"""Source code for ``todd.models.mean_std``: per-channel mean/std normalization mixin."""

# Names exported by ``from todd.models.mean_std import *``.
__all__ = ['MeanStdMixin']

from abc import ABC

import einops
import torch
from torch import nn


class MeanStdMixin(nn.Module, ABC):
    """Mixin that normalizes / denormalizes tensors with per-channel statistics.

    ``mean`` and ``std`` are registered as (persistent) buffers of shape
    ``(1, C, 1, 1)``, so they are saved in ``state_dict`` and moved together
    with the module by ``.to()`` / ``.cuda()``.
    """

    def __init__(
        self,
        *args,
        mean: tuple[float, ...],
        std: tuple[float, ...],
        **kwargs,
    ) -> None:
        """Register ``mean`` and ``std`` as broadcastable buffers.

        Args:
            mean: per-channel mean values (one per channel).
            std: per-channel standard deviations; must not contain zeros.
            *args: forwarded to the next ``__init__`` in the MRO.
            **kwargs: forwarded to the next ``__init__`` in the MRO.

        Raises:
            ValueError: if ``mean`` or ``std`` is not a flat sequence, if their
                lengths differ, or if ``std`` contains a zero.
        """
        super().__init__(*args, **kwargs)
        mean_ = self._to_channel_buffer(mean, 'mean')
        std_ = self._to_channel_buffer(std, 'std')
        # Without this check a length-1 ``mean`` would silently broadcast
        # against a multi-channel ``std`` (or fail later with an opaque error).
        if mean_.shape != std_.shape:
            raise ValueError(
                'mean and std must have the same number of channels, '
                f'got {mean_.shape[1]} and {std_.shape[1]}',
            )
        # Dividing by a zero std would silently produce inf/nan in normalize.
        if bool((std_ == 0).any()):
            raise ValueError(f'std must be non-zero, got {tuple(std)}')
        self.register_buffer('_mean', mean_)
        self.register_buffer('_std', std_)

    @staticmethod
    def _to_channel_buffer(
        values: tuple[float, ...],
        name: str,
    ) -> torch.Tensor:
        """Convert flat per-channel ``values`` to a ``(1, C, 1, 1)`` tensor."""
        # Force a floating dtype: integer inputs such as ``std=(1, 1, 1)``
        # would otherwise yield int64 buffers, which ``Module.half()`` and
        # ``Module.to(dtype)`` do not cast.
        tensor = torch.tensor(values, dtype=torch.get_default_dtype())
        if tensor.dim() != 1:
            raise ValueError(
                f'{name} must be a flat sequence of per-channel values, '
                f'got shape {tuple(tensor.shape)}',
            )
        # Indexing (rather than reshape(1, -1, 1, 1)) also handles C == 0.
        return tensor[None, :, None, None]

    @property
    def mean(self) -> torch.Tensor:
        """Per-channel mean buffer, shape ``(1, C, 1, 1)``."""
        return self.get_buffer('_mean')

    @property
    def std(self) -> torch.Tensor:
        """Per-channel standard-deviation buffer, shape ``(1, C, 1, 1)``."""
        return self.get_buffer('_std')

    def normalize(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``(image - mean) / std``.

        Args:
            image: tensor broadcastable against ``(1, C, 1, 1)``; presumably an
                ``(N, C, H, W)`` batch — confirm against callers.

        Returns:
            The normalized tensor.
        """
        return (image - self.mean) / self.std

    def denormalize(self, image: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`normalize`: return ``image * std + mean``.

        Args:
            image: tensor broadcastable against ``(1, C, 1, 1)``.

        Returns:
            The denormalized tensor.
        """
        return image * self.std + self.mean