Source code for todd.runners.strategies.ddp

# Public names exported by this module (used by ``from ... import *``).
__all__ = [
    'DDPStrategy',
]

from typing import TypeVar, cast

from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from ...bases.configs import Config
from ..registries import StrategyRegistry
from .cuda import CUDAStrategy

# Model type handled by the strategy below; bounded by
# ``DistributedDataParallel`` and used to parameterize ``CUDAStrategy``.
T = TypeVar('T', bound=DDP)


@StrategyRegistry.register_()
class DDPStrategy(CUDAStrategy[T]):
    """Strategy that wraps a model in ``DistributedDataParallel``.

    Registered with ``StrategyRegistry`` and built on top of
    ``CUDAStrategy``: the parent class prepares the model first, then this
    class wraps the result in ``torch.nn.parallel.DistributedDataParallel``.
    """

    def wrap_model(self, model: nn.Module, config: Config) -> T:
        """Prepare ``model`` via the parent strategy and wrap it in DDP.

        Args:
            model: the module to wrap.
            config: forwarded unchanged to the parent's ``wrap_model`` and
                then unpacked as keyword arguments to the
                ``DistributedDataParallel`` constructor.

        Returns:
            The ``DistributedDataParallel`` instance, cast to ``T``.
        """
        # NOTE(review): the same ``config`` is consumed by both the parent
        # and DDP; what the parent does with it is not visible here.
        prepared = super().wrap_model(model, config)
        return cast(T, DDP(prepared, **config))

    @property
    def module(self) -> nn.Module:
        """Return the ``module`` attribute of the runner's model.

        Presumably ``self.runner.model`` is the DDP wrapper produced by
        :meth:`wrap_model`, in which case this is the underlying unwrapped
        module -- confirm against the runner.
        """
        wrapped = self.runner.model
        return wrapped.module