Source code for todd.patches.torch.aten

__all__ = [
    'all_close',
    'random_int',
    'cosine_similarity',
]

from typing import Any, cast

import torch
import torch.distributed as dist
import torch.nn.functional as F

from .distributed import get_world_size


[docs] def random_int() -> int: seed = torch.randint(2**30, []) if get_world_size() > 1: dist.broadcast(seed, src=0) return cast(int, seed.item())
[docs] def all_close(x: Any, y: Any, *args, **kwargs) -> bool: if not isinstance(x, torch.Tensor): x = torch.tensor(x) if not isinstance(y, torch.Tensor): y = torch.tensor(y) return torch.allclose(x, y, *args, **kwargs)
[docs] def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # fixed CUDA out of memory error of torch.cosine_similarity x = F.normalize(x) y = F.normalize(y) return torch.einsum('x d, y d -> x y', x, y)