Source code for todd.models.losses.functional

# Names exported by ``from <this module> import *``: the generic functional
# loss wrapper, the normalisation mixin, and the concrete losses that bind a
# specific ``torch.nn.functional`` callable.
__all__ = [
    'FunctionalLoss',
    'NormMixin',
    'L1Loss',
    'MSELoss',
    'BCELoss',
    'BCEWithLogitsLoss',
    'CrossEntropyLoss',
    'CosineEmbeddingLoss',
]

from abc import ABC
from typing import Callable

import torch
import torch.nn.functional as F

from ...bases.configs import Config
from ...bases.registries import Item, RegistryMeta
from ...registries import ModelRegistry
from ..registries import LossRegistry
from .base import BaseLoss, Reduction


@LossRegistry.register_()
class FunctionalLoss(BaseLoss):
    """Loss module that delegates the computation to a callable.

    The callable is stored at construction time and invoked from
    :meth:`forward` with a ``reduction`` keyword argument.
    """

    def __init__(
        self,
        *args,
        func: Callable[..., torch.Tensor],
        **kwargs,
    ) -> None:
        """Store the loss callable.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            func: Callable invoked by :meth:`forward` as
                ``func(pred, target, *args, reduction=..., **kwargs)``.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self._func = func

    @classmethod
    def build_pre_hook(
        cls,
        config: Config,
        registry: RegistryMeta,
        item: Item,
    ) -> Config:
        """Resolve ``config.func`` before the instance is built.

        Args:
            config: Build configuration; may or may not contain ``func``.
            registry: Registry passed through to the parent hook.
            item: Item passed through to the parent hook.

        Returns:
            The configuration returned by the parent hook, with ``func``
            replaced by ``ModelRegistry.build_or_return(config.func)`` when
            the key is present.
        """
        hooked = super().build_pre_hook(config, registry, item)
        if 'func' not in hooked:
            return hooked
        hooked.func = ModelRegistry.build_or_return(hooked.func)
        return hooked

    def forward(  # pylint: disable=arguments-differ
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        *args,
        mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Evaluate the wrapped callable on ``pred`` and ``target``.

        Args:
            pred: First positional argument given to the callable.
            target: Second positional argument given to the callable.
            *args: Extra positional arguments forwarded to the callable.
            mask: Optional tensor. When given, the callable is asked for an
                unreduced result, which is then handed to ``self._reduce``
                together with the mask.
            **kwargs: Extra keyword arguments forwarded to the callable.

        Returns:
            The callable's output reduced with ``self.reduction`` when no
            mask is given, otherwise the result of ``self._reduce``.
        """
        if mask is not None:
            # Reduction is deferred to ``_reduce`` so the mask can take part.
            unreduced = self._func(
                pred,
                target,
                *args,
                reduction=Reduction.NONE.value,
                **kwargs,
            )
            return self._reduce(unreduced, mask)

        return self._func(
            pred,
            target,
            *args,
            reduction=self.reduction,
            **kwargs,
        )
class NormMixin(FunctionalLoss, ABC):
    """Mixin that can normalise both inputs before the loss is computed."""

    def __init__(self, *args, norm: bool = False, **kwargs) -> None:
        """Record whether inputs should be normalised.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            norm: When true, :meth:`forward` applies ``F.normalize`` to both
                ``pred`` and ``target``.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self._norm = norm

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        *args,
        mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Optionally normalise the inputs, then defer to the parent.

        Args:
            pred: Prediction tensor.
            target: Target tensor.
            *args: Extra positional arguments forwarded to the parent.
            mask: Optional mask forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        if self._norm:
            # ``F.normalize`` is called with its default arguments.
            pred, target = F.normalize(pred), F.normalize(target)
        return super().forward(pred, target, *args, mask=mask, **kwargs)
@LossRegistry.register_()
class L1Loss(NormMixin, FunctionalLoss):
    """Functional loss bound to ``F.l1_loss``."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the loss with ``func`` fixed to ``F.l1_loss``.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            *args,
            func=F.l1_loss,
            **kwargs,
        )
@LossRegistry.register_()
class MSELoss(NormMixin, FunctionalLoss):
    """Functional loss bound to ``F.mse_loss``."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the loss with ``func`` fixed to ``F.mse_loss``.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            *args,
            func=F.mse_loss,
            **kwargs,
        )
@LossRegistry.register_()
class BCELoss(FunctionalLoss):
    """Functional loss bound to ``F.binary_cross_entropy``."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the loss with ``func`` fixed to ``F.binary_cross_entropy``.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            *args,
            func=F.binary_cross_entropy,
            **kwargs,
        )
@LossRegistry.register_()
class BCEWithLogitsLoss(FunctionalLoss):
    """Functional loss bound to ``F.binary_cross_entropy_with_logits``."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the loss with the logits-based binary cross entropy.

        ``func`` is fixed to ``F.binary_cross_entropy_with_logits``.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        bce_with_logits = F.binary_cross_entropy_with_logits
        super().__init__(*args, func=bce_with_logits, **kwargs)
@LossRegistry.register_()
class CrossEntropyLoss(FunctionalLoss):
    """Functional loss bound to ``F.cross_entropy``."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the loss with ``func`` fixed to ``F.cross_entropy``.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            *args,
            func=F.cross_entropy,
            **kwargs,
        )
@LossRegistry.register_()
class CosineEmbeddingLoss(FunctionalLoss):
    """Functional loss bound to ``F.cosine_embedding_loss``."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the loss with ``func`` fixed to ``F.cosine_embedding_loss``.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            *args,
            func=F.cosine_embedding_loss,
            **kwargs,
        )