Source code for todd.models.losses.base

# Names exported by a star-import of this module. Only ``BaseLoss`` is listed;
# the ``Reduction`` enum defined in this module is not part of ``__all__``.
__all__ = [
    'BaseLoss',
]

from abc import ABC, abstractmethod
from enum import StrEnum

import torch
from torch import nn

from ...bases.configs import Config
from ...bases.registries import BuildPreHookMixin, Item, RegistryMeta
from .schedulers import BaseScheduler, SchedulerRegistry


class Reduction(StrEnum):
    """Reduction modes understood by ``BaseLoss._reduce``.

    Members are strings, so a mode can be looked up from its value, e.g.
    ``Reduction('mean')``. For ``MEAN``, ``SUM`` and ``PROD`` the value is
    also used as the name of the tensor method that performs the reduction.
    """

    NONE = 'none'  # return the (optionally masked) loss without reducing it
    MEAN = 'mean'  # reduce with ``loss.mean()``
    SUM = 'sum'  # reduce with ``loss.sum()``
    PROD = 'prod'  # reduce with ``loss.prod()``
    WEIGHTED = 'weighted'  # ``loss.sum() / mask.sum()``; a mask is required


class BaseLoss(BuildPreHookMixin, nn.Module, ABC):
    """Abstract base class for loss modules.

    Subclasses implement :meth:`forward` and normally call :meth:`_reduce`
    on their element-wise loss. A forward hook registered in ``__init__``
    passes whatever ``forward`` returns through :meth:`_scale`, so callers
    always receive the loss multiplied by the current weight.
    """

    def __init__(
        self,
        reduction: str | Reduction = Reduction.MEAN,
        weight: float | BaseScheduler = 1.0,
        bound: float | None = None,
        **kwargs,
    ) -> None:
        """Initialize the loss.

        Args:
            reduction: reduction mode, either a :class:`Reduction` member or
                its string value (matched case-insensitively).
            weight: loss weight. A plain number is wrapped in
                ``BaseScheduler(gain=weight)``; a scheduler is used as is.
            bound: optional upper bound on the weighted loss. When given,
                the loss returned by ``forward`` must hold exactly one
                element, because its value is read with ``Tensor.item()``.
            **kwargs: forwarded to the parent constructors.
        """
        super().__init__(**kwargs)
        self._reduction = (
            reduction
            if isinstance(reduction, Reduction) else
            Reduction(reduction.lower())
        )
        self._weight = (
            weight
            if isinstance(weight, BaseScheduler) else
            BaseScheduler(gain=weight)
        )
        self._bound = bound
        # A forward hook that returns a value replaces the module output, so
        # every call of the module yields the scaled loss.
        self.register_forward_hook(lambda m, i, o: self._scale(o))

    @classmethod
    def build_pre_hook(
        cls,
        config: Config,
        registry: RegistryMeta,
        item: Item,
    ) -> Config:
        """Prepare ``config`` before the loss is built.

        After the parent hook has run, a ``weight`` entry (if present) is
        replaced by the result of ``SchedulerRegistry.build_or_return``.

        Args:
            config: the build config.
            registry: passed through to the parent hook.
            item: passed through to the parent hook.

        Returns:
            The updated config.
        """
        config = super().build_pre_hook(config, registry, item)
        if 'weight' in config:
            config.weight = SchedulerRegistry.build_or_return(config.weight)
        return config

    @property
    def reduction(self) -> Reduction:
        """The configured reduction mode."""
        return self._reduction

    @property
    def weight(self) -> float:
        """The current loss weight, obtained by calling the scheduler."""
        return self._weight()

    def step(self) -> None:
        """Advance the weight scheduler by one step."""
        self._weight.step()

    def _reduce(
        self,
        loss: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the optional mask and the configured reduction.

        Args:
            loss: the unreduced loss.
            mask: optional tensor multiplied into ``loss`` before reducing.
                Required when the reduction is ``Reduction.WEIGHTED``.

        Returns:
            The masked loss for ``NONE``; ``loss.sum() / mask.sum()`` for
            ``WEIGHTED`` (a zero scalar if the mask sums to almost zero);
            otherwise the result of the tensor method named by the
            reduction (``mean``, ``sum`` or ``prod``).
        """
        if mask is not None:
            loss = loss * mask
        if self._reduction is Reduction.NONE:
            return loss
        if self._reduction is Reduction.WEIGHTED:
            assert mask is not None, 'WEIGHTED reduction requires a mask'
            weight = mask.sum()
            # Avoid dividing by a (near) zero mask total.
            if weight.abs() < 1e-6:
                return loss.new_zeros([])
            return loss.sum() / weight
        return getattr(loss, self._reduction.value)()

    def _scale(self, loss: torch.Tensor) -> torch.Tensor:
        """Multiply ``loss`` by the current weight, honoring ``bound``.

        If a bound is set and the weighted loss is positive, the weight is
        shrunk so that ``weight * loss`` does not exceed the bound. A zero or
        negative weighted loss is returned with the plain weight applied:
        dividing by it would either raise ``ZeroDivisionError`` or produce a
        negative coefficient that flips the sign of the loss.

        Args:
            loss: the loss returned by ``forward``. Must contain exactly one
                element when a bound is set.

        Returns:
            The weighted loss.
        """
        weight = self.weight
        if self._bound is not None:
            weighted = weight * loss.item()
            if weighted > 0:
                # weight <= bound / loss; rebinding (rather than ``*=``) keeps
                # the value returned by the scheduler untouched.
                weight = weight * min(self._bound / weighted, 1.)
        return weight * loss

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Compute the loss; the result is scaled by the forward hook."""
        pass