Created
April 5, 2018 07:56
-
-
Save ed-alertedh/58d3eb96cf1ba70542b657471dd377ca to your computer and use it in GitHub Desktop.
Restrict CUDA device availability from Python - adapted from Yaroslav Bulatov's code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import py3nvml | |
# GPU picking | |
# http://stackoverflow.com/a/41638727/419116 | |
# Credit to Yaroslav Bulatov for the general idea, the error checking and the CUDA_DEVICE_ORDER workaround
# https://github.com/yaroslavvb/stuff/blob/252668efe602da30a34948bfb8b4f89f1ef9d60a/notebook_util.py | |
# This uses py3nvml to avoid issues parsing the output of nvidia-smi on different platforms | |
# Set to True once setup_gpus() has completed successfully, so that later
# calls become no-ops.
_already_initialised = False


def setup_gpus(num_gpus: int, memory_fraction: float = 0.8) -> None:
    """Restrict this process to ``num_gpus`` CUDA devices chosen by py3nvml.

    Only the first successful call does any work; every later call returns
    immediately, regardless of the arguments it is given. A call that raises
    leaves the module un-initialised, so it may be retried.

    Args:
        num_gpus: Number of GPU devices to grab.
        memory_fraction: Forwarded to ``py3nvml.grab_gpus`` as
            ``gpu_fraction``. Presumably the fraction of a device's memory
            that must be available for it to be eligible -- confirm against
            the py3nvml documentation.

    Raises:
        AssertionError: If TensorFlow has already been imported, or if
            ``py3nvml.grab_gpus`` reports a number of grabbed devices that
            differs from ``num_gpus``.
    """
    global _already_initialised
    if _already_initialised:
        return

    # Explicit raises instead of ``assert`` statements: asserts are stripped
    # under ``python -O``, which would silently disable both checks. The
    # exception type is kept as AssertionError for existing callers.
    if 'tensorflow' in sys.modules:
        raise AssertionError('GPU setup must happen before importing TensorFlow')

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see tensorflow issue #152
    num_grabbed = py3nvml.grab_gpus(num_gpus, gpu_fraction=memory_fraction)
    if num_grabbed != num_gpus:
        raise AssertionError(
            'Could not grab {} GPU devices with {}% memory available'.format(
                num_gpus, memory_fraction * 100))

    # An empty device list is replaced with "-1" (see tensorflow issues:
    # #16824, #2175). ``.get`` is used so that an unset variable is left
    # alone instead of raising KeyError.
    if os.environ.get('CUDA_VISIBLE_DEVICES') == "":
        os.environ['CUDA_VISIBLE_DEVICES'] = "-1"

    _already_initialised = True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment