Created
December 9, 2023 00:07
-
-
Save Oil3/92a93d36f00cc76a8b18c99783bc7c9c to your computer and use it in GitHub Desktop.
PyTorch device-selection helpers with Apple MPS support, adapted from CodeFormer.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def parse_version_triplet(version):
    """Return the leading ``(major, minor, patch)`` integers of a version string.

    Only the numeric ``X.Y[.Z]`` prefix is inspected. Any pre-release, dev or
    local suffix after it is ignored, e.g. ``a0+32f93b1``, ``+cu118``,
    ``.dev20231201+cpu`` or ``a0+cxx11.abi``. A missing patch component is
    treated as ``0``.

    Args:
        version: A version string such as ``torch.__version__``.

    Returns:
        A ``(major, minor, patch)`` tuple of ints, suitable for ordered
        comparison against another tuple.

    Raises:
        ValueError: If *version* does not start with ``X.Y`` digits.
    """
    # re.match anchors at the start only, so arbitrary trailing build metadata
    # no longer makes the parse fail. The old fully-anchored pattern raised
    # IndexError on such suffixes.
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


# True when the installed PyTorch is at least 1.12.0. It is used to guard access
# to ``torch.backends.mps``, because the MPS backend first shipped in PyTorch 1.12.
IS_HIGH_VERSION = parse_version_triplet(torch.__version__) >= (1, 12, 0)
def gpu_is_available():
    """Report whether a GPU backend usable by PyTorch is present.

    The Apple Metal Performance Shaders (MPS) backend is consulted first, but
    only when ``IS_HIGH_VERSION`` is true. Otherwise, or if MPS is not
    available, CUDA counts as usable only when both CUDA and cuDNN report
    themselves available.

    Returns:
        bool: ``True`` if MPS, or CUDA together with cuDNN, is available;
        otherwise ``False``.
    """
    # Short-circuit so torch.backends.mps is never touched on older releases.
    mps_ready = IS_HIGH_VERSION and torch.backends.mps.is_available()
    if mps_ready:
        return True

    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return bool(cuda_ready)
def get_device(gpu_id=None): | |
if gpu_id is None: | |
gpu_str = '' | |
elif isinstance(gpu_id, int): | |
gpu_str = f':{gpu_id}' | |
else: | |
raise TypeError('Input should be int value.') | |
if IS_HIGH_VERSION: | |
if torch.backends.mps.is_available(): | |
return torch.device('mps'+gpu_str) | |
return torch.device('cuda'+gpu_str if torch.c |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment