Source code for todd.patches.torch.distributed

__all__ = [
    'get_rank',
    'get_local_rank',
    'get_world_size',
    'all_gather',
    'all_gather_object',
]

import functools
import operator
import os

import torch
import torch.distributed as dist


[docs] def get_rank(*args, **kwargs) -> int: if dist.is_initialized(): return dist.get_rank(*args, **kwargs) return 0
[docs] def get_local_rank(*args, **kwargs) -> int: if 'LOCAL_RANK' in os.environ: return int(os.environ['LOCAL_RANK']) return get_rank(*args, **kwargs)
[docs] def get_world_size(*args, **kwargs) -> int: if dist.is_initialized(): return dist.get_world_size(*args, **kwargs) return 1
def all_gather_decorator(func): @functools.wraps(func) def wrapper(tensor: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]: tensor = tensor.detach().clone() world_size = get_world_size() if world_size <= 1: return [tensor] return func(tensor, world_size, *args, **kwargs) return wrapper
[docs] @all_gather_decorator def all_gather( tensor: torch.Tensor, world_size: int, *args, **kwargs, ) -> list[torch.Tensor]: tensors = [torch.empty_like(tensor) for _ in range(world_size)] dist.all_gather(tensors, tensor, *args, **kwargs) return tensors
[docs] @all_gather_decorator def all_gather_object( tensor: torch.Tensor, world_size: int, *args, **kwargs, ) -> list[torch.Tensor]: shape = tensor.shape shapes = [shape] * world_size dist.all_gather_object(shapes, shape) numel_list = [functools.reduce(operator.mul, s, 1) for s in shapes] max_numel = max(numel_list) container = tensor.new_empty(max_numel) containers = [torch.empty_like(container) for _ in range(world_size)] container[:tensor.numel()] = tensor.view(-1) dist.all_gather(containers, container, *args, **kwargs) tensors = [ container[:numel].view(shape) for shape, numel, container in zip(shapes, numel_list, containers) ] return tensors