Last active
June 6, 2024 03:39
-
-
Save pszemraj/6b57610a49db9449e4784ac0614d2f6c to your computer and use it in GitHub Desktop.
Modern way to auto-enable TF32 in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import logging | |
def check_ampere_gpu():
    """
    Enable TF32 math in PyTorch if the current CUDA GPU is Ampere or newer.

    NVIDIA GPUs with compute capability 8.0 or higher (Ampere and later)
    support TF32. When the current CUDA device qualifies, this sets both
    ``torch.backends.cuda.matmul.allow_tf32`` and
    ``torch.backends.cudnn.allow_tf32`` to True. These are process-wide
    backend flags. They are not per-device, so only the current device is
    inspected before they are flipped.

    Errors raised while querying the device are logged as a warning instead
    of being propagated. This makes the function safe to call
    unconditionally at start-up.

    Returns:
        bool: True if TF32 was enabled. False otherwise: no CUDA device is
        available, the GPU predates Ampere, or the device query failed.
    """
    logger = logging.getLogger(__name__)

    if not torch.cuda.is_available():
        logger.info("No GPU detected, running on CPU.")
        return False

    try:
        device = torch.cuda.current_device()
        major, minor = torch.cuda.get_device_capability(device)
        gpu_name = torch.cuda.get_device_name(device)
    except Exception as e:
        # Deliberate best-effort: a failed device query must not crash the
        # caller, so report it and leave the TF32 flags untouched.
        logger.warning("Error occurred while checking GPU: %s", e)
        return False

    # Ampere corresponds to compute capability 8.0; older GPUs lack TF32.
    if major < 8:
        logger.info(
            "%s (compute capability %d.%d) does not support NVIDIA Ampere or later.",
            gpu_name,
            major,
            minor,
        )
        return False

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    logger.info(
        "%s (compute capability %d.%d) supports NVIDIA Ampere or later, "
        "enabled TF32 in PyTorch.",
        gpu_name,
        major,
        minor,
    )
    return True
# Runs at import/execution time of this file: when the current CUDA GPU has
# compute capability >= 8.0, this sets torch's TF32 backend flags as a side
# effect. NOTE(review): consider moving behind `if __name__ == "__main__":`
# if this file is ever imported as a library module.
check_ampere_gpu()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment