Source code for todd.patches.torch.misc
__all__ = [
'get_device',
]
import torch
[docs]
def get_device() -> str:
if torch.cuda.is_available():
return 'cuda'
if torch.backends.mps.is_available():
return 'mps'
return 'cpu'