# Source code for todd.runners.strategies.cuda

__all__ = [
    'CUDAStrategy',
]

from typing import TypeVar

import torch
import torch.distributed as dist
from torch import nn

from ...bases.configs import Config
from ...patches.torch import get_local_rank
from ...utils import Store
from ..registries import StrategyRegistry
from .base import BaseStrategy

T = TypeVar('T', bound=nn.Module)


@StrategyRegistry.register_()
class CUDAStrategy(BaseStrategy[T]):
    """Strategy that runs the model on a CUDA device.

    `setup` makes sure a ``torch.distributed`` process group exists and
    selects the CUDA device for the current process. `map_model` moves the
    model onto CUDA after the parent strategy has mapped it.
    """

    def setup(self, config: Config) -> None:
        """Initialize the process group if needed and select a CUDA device.

        Args:
            config: strategy configuration. The optional
                ``init_process_group`` entry is unpacked as keyword
                arguments into ``torch.distributed.init_process_group``;
                when it is absent, ``Config(backend='nccl')`` is used.

        Raises:
            AssertionError: if ``Store.cuda`` is falsy.
        """
        # NOTE(review): presumably `Store.cuda` flags whether CUDA is usable;
        # confirm against the definition of `Store`.
        assert Store.cuda

        # Only create a process group when the caller has not already done
        # so; an existing group is left untouched.
        if not dist.is_initialized():
            init_process_group = config.get(
                'init_process_group',
                Config(backend='nccl'),
            )
            dist.init_process_group(**init_process_group)

        # The modulo keeps the device index below the number of devices
        # reported by `torch.cuda.device_count()`.
        torch.cuda.set_device(get_local_rank() % torch.cuda.device_count())

    def map_model(self, model: nn.Module, config: Config) -> nn.Module:
        """Map the model via the parent strategy, then move it to CUDA.

        Args:
            model: the module to map.
            config: configuration forwarded unchanged to the parent's
                ``map_model``.

        Returns:
            The result of ``.cuda()`` called on the module returned by the
            parent's ``map_model``.
        """
        model = super().map_model(model, config)
        return model.cuda()