Created June 17, 2024 23:44
-
-
Save larroy/f6070a854e47f597c3beedc7b3471ac6 to your computer and use it in GitHub Desktop.
Get the number of GPUs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def gpu_count() -> int:
    """Return the number of CUDA devices visible through the driver API.

    Loads the CUDA driver library with ctypes, initializes it with ``cuInit``
    and queries ``cuDeviceGetCount``. Only the driver shared library is
    needed; no CUDA toolkit or Python GPU framework is required. Only Linux
    shared-object names are tried.

    Returns:
        The device count, or 0 when the driver library cannot be loaded,
        the driver reports no device, or any driver call fails. Failures are
        logged, never raised.
    """
    import ctypes
    import logging

    CUDA_SUCCESS = 0
    CUDA_ERROR_NO_DEVICE = 100

    # The NVIDIA driver normally installs only the versioned soname
    # (libcuda.so.1); the unversioned libcuda.so symlink usually comes from
    # development packages, so try the versioned name first.
    cuda = None
    load_error = None
    for lib_name in ("libcuda.so.1", "libcuda.so"):
        try:
            cuda = ctypes.CDLL(lib_name)
        except OSError as err:
            load_error = err
        else:
            break
    if cuda is None:
        logging.error("Loading libcuda failed: %s", load_error)
        return 0

    # ctypes' default restype is c_int, so driver calls return plain ints.
    result = cuda.cuInit(0)
    if result == CUDA_ERROR_NO_DEVICE:
        # A working driver with zero devices is a valid answer, not an error.
        logging.debug("cuInit: no CUDA-capable device detected")
        return 0
    if result != CUDA_SUCCESS:
        logging.error("cuInit returned an error: %d", result)
        return 0

    num_gpus = ctypes.c_int()
    result = cuda.cuDeviceGetCount(ctypes.byref(num_gpus))
    if result != CUDA_SUCCESS:
        logging.error("cuDeviceGetCount returned an error: %d", result)
        return 0
    return num_gpus.value
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment