@vadim0x60
Last active November 20, 2024 12:24
Choose the right PyTorch device, always: try each backend in priority order and return the first one that can actually hold a tensor.
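```python
def torch_device(priority=('xla', 'cuda', 'mps', 'xpu', 'cpu')):
    """Return the first torch.device in `priority` that can hold a tensor."""
    import torch  # imported lazily so this module loads without torch

    t = torch.Tensor([0])
    for name in priority:
        try:
            # Build the device inside the try: older PyTorch builds may not
            # recognise every device-type string and raise RuntimeError here.
            device = torch.device(name)
            t.to(device)  # raises if this backend is unavailable
            return device
        except (RuntimeError, AssertionError):
            pass  # e.g. "Torch not compiled with CUDA enabled"
    raise RuntimeError(f'none of the devices {priority} is available')
```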
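The selection logic is simply "first candidate whose probe does not raise". A minimal sketch of that pattern with the probe passed in, so the fallback order can be exercised without PyTorch or any accelerator installed; `first_available` and `fake_probe` are hypothetical names introduced only for illustration, and `fake_probe` merely simulates a machine where only `mps` and `cpu` work:

```python
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def first_available(
    candidates: Iterable[str],
    probe: Callable[[str], T],
    errors: tuple = (RuntimeError, AssertionError),
) -> T:
    """Return probe(name) for the first candidate whose probe does not raise."""
    tried = []
    for name in candidates:
        tried.append(name)
        try:
            return probe(name)
        except errors:
            continue  # backend unavailable: fall through to the next one
    raise RuntimeError(f"no usable candidate among {tried}")


# Hypothetical probe: pretends only 'mps' and 'cpu' are usable.
def fake_probe(name: str) -> str:
    if name not in {"mps", "cpu"}:
        raise RuntimeError(f"{name} backend not available")
    return name


print(first_available(["xla", "cuda", "mps", "xpu", "cpu"], fake_probe))  # mps
```

With PyTorch installed, a probe that moves a small tensor to `torch.device(name)` and then returns that device reproduces `torch_device` above.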