Source code for todd.models.norms

# Names exported by a star-import of this module.
__all__ = [
    'BATCHNORMS',
    'AdaptiveGroupNorm',
    'AdaptiveLayerNorm',
]

import math
from abc import ABC, abstractmethod

import torch
from torch import nn
from torch.nn.modules import batchnorm

from .registries import NormRegistry

# Hand the stock torch normalization layers to NormRegistry under short
# keyword aliases. ``BN`` and ``IN`` are passed the same classes as ``BN2d``
# and ``IN2d`` respectively.
# NOTE(review): presumably this makes the layers buildable by alias through
# the registry -- confirm against NormRegistry.update in .registries.
NormRegistry.update(
    BN1d=nn.BatchNorm1d,
    BN2d=nn.BatchNorm2d,
    BN=nn.BatchNorm2d,
    BN3d=nn.BatchNorm3d,
    SyncBN=nn.SyncBatchNorm,
    GN=nn.GroupNorm,
    LN=nn.LayerNorm,
    IN1d=nn.InstanceNorm1d,
    IN2d=nn.InstanceNorm2d,
    IN=nn.InstanceNorm2d,
    IN3d=nn.InstanceNorm3d,
)

# Tuple of torch's batch-norm classes: the 1d/2d/3d layers plus the
# synchronized variant.
# NOTE(review): a tuple of classes is the form ``isinstance`` accepts; no use
# is visible in this file -- confirm against callers.
BATCHNORMS = (
    batchnorm.BatchNorm1d,
    batchnorm.BatchNorm2d,
    batchnorm.BatchNorm3d,
    batchnorm.SyncBatchNorm,
)


class AdaptiveMixin(nn.Module, ABC):
    """Mixin that modulates another layer's output with a condition tensor.

    ``forward`` delegates to the next ``forward`` in the MRO, projects the
    condition through ``self._linear``, asks the subclass to turn that
    projection into a ``(weight, bias)`` pair, and returns
    ``x * (1 + weight) + bias``.
    """

    # Projection applied to the condition. It is only declared here; it is
    # never assigned in this class, so subclasses must create it themselves.
    _linear: nn.Linear

    def __init__(self, *args, condition_dim: int, **kwargs) -> None:
        """Initialize the wrapped layer and remember the condition size.

        Args:
            *args: Positional arguments forwarded to the next ``__init__``
                in the MRO.
            condition_dim: Stored as ``self._condition_dim`` for subclasses
                to use when building ``_linear``.
            **kwargs: Keyword arguments forwarded to the next ``__init__``
                in the MRO.
        """
        super().__init__(*args, **kwargs)
        self._condition_dim = condition_dim

    @abstractmethod
    def _forward(
        self,
        x: torch.Tensor,
        condition: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Convert the projected condition into modulation tensors.

        Args:
            x: Output of the wrapped layer's ``forward``.
            condition: Condition after it has passed through ``_linear``.

        Returns:
            A ``(weight, bias)`` pair that ``forward`` combines with ``x``.
        """

    def forward(
        self,
        *args,
        condition: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped layer, then scale and shift its output.

        Args:
            *args: Positional arguments for the wrapped layer's ``forward``.
            condition: Tensor fed to ``self._linear``.
            **kwargs: Keyword arguments for the wrapped layer's ``forward``.

        Returns:
            ``x * (1 + weight) + bias`` where ``x`` is the wrapped layer's
            output and ``(weight, bias)`` comes from ``_forward``.
        """
        normalized: torch.Tensor = super().forward(*args, **kwargs)
        projected = self._linear(condition)
        scale, shift = self._forward(normalized, projected)
        # The ``1 +`` keeps the output equal to ``normalized`` when the
        # predicted scale and shift are both zero.
        return normalized * (1 + scale) + shift


@NormRegistry.register_('AdaGN')
class AdaptiveGroupNorm(AdaptiveMixin, nn.GroupNorm):  # type: ignore[misc]
    """Group normalization whose scale and shift come from a condition.

    The layer is built with ``affine=False``; a linear projection of the
    condition supplies ``2 * num_channels`` values that are split into the
    ``weight`` and ``bias`` consumed by ``AdaptiveMixin.forward``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build the affine-free group norm and the condition projection.

        Args:
            *args: Positional arguments forwarded up the MRO.
            **kwargs: Keyword arguments forwarded up the MRO. ``affine`` is
                always passed as ``False`` here, so supplying it again
                raises ``TypeError``.
        """
        super().__init__(*args, affine=False, **kwargs)
        # One half of the outputs is the scale, the other half the shift.
        out_features = 2 * self.num_channels
        self._linear = nn.Linear(
            in_features=self._condition_dim,
            out_features=out_features,
        )

    def _forward(
        self,
        x: torch.Tensor,
        condition: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split the projected condition into broadcastable weight and bias.

        Args:
            x: Normalized tensor; must have at least two dimensions.
            condition: Projected condition; must be two-dimensional.

        Returns:
            ``(weight, bias)``: the two halves of ``condition`` along axis 1,
            each padded with trailing singleton axes up to ``x.dim()``.
        """
        rank = x.dim()
        assert rank >= 2
        assert condition.dim() == 2
        # Append singleton axes, then trim to the rank of ``x`` so the result
        # broadcasts over every axis after the second one.
        padded = condition.shape + (1, ) * rank
        expanded = condition.reshape(padded[:rank])
        weight, bias = torch.chunk(expanded, 2, dim=1)
        return weight, bias


@NormRegistry.register_('AdaLN')
class AdaptiveLayerNorm(AdaptiveMixin, nn.LayerNorm):  # type: ignore[misc]
    """Layer normalization whose scale and shift come from a condition.

    The layer is built with ``elementwise_affine=False``; a linear projection
    of the condition supplies ``2 * prod(normalized_shape)`` values that are
    split into the ``weight`` and ``bias`` consumed by
    ``AdaptiveMixin.forward``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build the affine-free layer norm and the condition projection.

        Args:
            *args: Positional arguments forwarded up the MRO.
            **kwargs: Keyword arguments forwarded up the MRO.
                ``elementwise_affine`` is always passed as ``False`` here, so
                supplying it again raises ``TypeError``.
        """
        super().__init__(*args, elementwise_affine=False, **kwargs)
        # One scale and one shift value per normalized element.
        num_elements = math.prod(self.normalized_shape)
        self._linear = nn.Linear(
            in_features=self._condition_dim,
            out_features=2 * num_elements,
        )

    def _forward(
        self,
        x: torch.Tensor,
        condition: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split the projected condition and reshape it to broadcast with x.

        Args:
            x: Normalized tensor.
            condition: Projected condition; its last axis is halved into
                weight and bias, its other axes are treated as leading axes.

        Returns:
            ``(weight, bias)``, each reshaped to the condition's leading
            axes, followed by singleton axes, followed by
            ``normalized_shape``, for a total of ``x.dim()`` axes.
        """
        rank = x.dim()
        tail = len(self.normalized_shape)
        leading: tuple[int, ...] = condition.shape[:-1]
        assert len(leading) + tail <= rank
        # Pad the leading axes with ones up to the rank of ``x``, then swap
        # the trailing axes for the normalized shape.
        padded = (leading + (1, ) * rank)[:rank]
        target = padded[:-tail] + self.normalized_shape
        weight, bias = torch.chunk(condition, 2, dim=-1)
        return weight.reshape(target), bias.reshape(target)