# pylint: disable=using-constant-test
# Public names of this module (what ``from <module> import *`` exposes).
__all__ = [
'init_seed',
'set_seed_temp',
]
import hashlib
import random
from contextlib import contextmanager
from typing import Generator
import numpy as np
import torch
from torch.backends import cudnn
from ..loggers import logger
from ..patches.torch import random_int
from .stores import Store
[docs]
def init_seed(seed: int) -> None:
    """Seed the Python, NumPy and torch random number generators.

    The seed is first reduced modulo ``2**30`` and the reduced value is
    handed to ``random``, ``numpy.random`` and ``torch``. When
    ``Store.cuda`` is truthy, ``torch.cuda.manual_seed`` is called with the
    same value as well.

    Args:
        seed: Integer seed; only its value modulo ``2**30`` is used.
    """
    wrapped = seed % (1 << 30)
    # Same order as always: stdlib first, then NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(wrapped)
    if Store.cuda:
        torch.cuda.manual_seed(wrapped)
[docs]
@contextmanager
def set_seed_temp(
    seed=None,
    deterministic: bool = False,
    benchmark: bool = True,
) -> Generator[None, None, None]:
    """Temporarily seed the RNGs and set the cuDNN flags for a ``with`` block.

    On entry the current states of ``random``, ``numpy.random`` and torch
    (plus the CUDA generator when ``Store.cuda`` is truthy) and the current
    ``cudnn.deterministic`` / ``cudnn.benchmark`` flags are saved. The flags
    are then set to the requested values and ``init_seed`` is called. On
    exit everything saved is restored, including when the body of the
    ``with`` block raises.

    Args:
        seed: ``None`` to take a value from ``random_int()``; an ``int`` to
            use as is; ``bytes`` to be hashed; any other object is converted
            with ``str(seed).encode()`` and then hashed. Hashing uses a
            4-byte BLAKE2b digest interpreted as a hexadecimal integer.
        deterministic: Value assigned to ``cudnn.deterministic`` inside the
            block.
        benchmark: Value assigned to ``cudnn.benchmark`` inside the block.

    Yields:
        None.
    """
    if seed is None:
        seed = random_int()
    elif not isinstance(seed, int):
        if not isinstance(seed, bytes):
            seed = str(seed).encode()
        # 4-byte digest -> 8 hex characters -> integer below 2**32.
        seed = int(hashlib.blake2b(seed, digest_size=4).hexdigest(), 16)
    logger.info("Setting seed to %d", seed)

    # Snapshot everything that is about to be modified.
    random_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state() if Store.cuda else None
    prev_deterministic = cudnn.deterministic
    prev_benchmark = cudnn.benchmark

    try:
        cudnn.deterministic = deterministic
        cudnn.benchmark = benchmark
        init_seed(seed)
        yield
    finally:
        # Runs on normal exit and when the with-body raises, so the
        # temporary seed and cuDNN flags never leak out of the block.
        random.setstate(random_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)
        # Restore the CUDA state only if one was captured on entry.
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state)
        cudnn.deterministic = prev_deterministic
        cudnn.benchmark = prev_benchmark