A quick way to distribute embarrassingly parallel jobs across multiple GPUs (or other resources) with IPyParallel
import os
import socket
import time
import logging

# number of GPUs available on each known host
WS_N_GPUS = {
    'turagas-ws1': 2,
    'turagas-ws2': 2,
    'turagas-ws3': 2,
    'turagas-ws4': 2,
    'c04u01': 8,
    'c04u07': 8,
    'c04u12': 8,
    'c04u17': 8,
}

def gpu_job_runner(job_fnc, job_args, ipp_profile='ssh_gpu_py2', log_name=None,
                   status_interval=600, allow_engine_overlap=True):
    """ Distribute a set of jobs across an IPyParallel 'GPU cluster'

    Requires that the cluster has already been started with
    `ipcluster start --profile={}`.format(ipp_profile).
    Checks on the jobs every status_interval seconds, logging their status.

    Args:
      job_fnc: the function to distribute
         must accept `device` as a kwarg, as this function is wrapped so that
         device is bound within the engine namespace
         returned values are ignored
      job_args: list of args passed to job_fnc - list
      ipp_profile: name of the GPU IPyParallel profile - str
      log_name: (optional) name for log
      status_interval: (optional) the amount of time, in seconds, to wait before querying
         the AsyncResult object for the status of the jobs
    """
    from ipyparallel import Client, RemoteError, Reference
    import inspect

    # set up logging
    log_path = os.path.expanduser("~/logs/deepspike/job_runner.log")
    log_name = log_name or 'job_runner'
    logger = setup_logging(log_name, log_path)

    # TODO: this isn't strictly necessary
    try:
        # check that job_fnc accepts a device kwarg
        args = inspect.getargspec(job_fnc)[0]
        assert 'device' in args
    except AssertionError:
        logger.critical("job_fnc does not accept a device kwarg. Halting.")
        raise

    client = Client(profile=ipp_profile)
    logger.info("Successfully initialized client on %s with %s engines", ipp_profile, len(client))
    # assign each engine to a GPU
    engines_per_host = {}
    device_assignments = []
    engine_hosts = client[:].apply(socket.gethostname).get()

    for host in engine_hosts:
        if host in engines_per_host:
            device_assignments.append('/gpu:{}'.format(engines_per_host[host]))
            engines_per_host[host] += 1
        else:
            device_assignments.append('/gpu:0')
            engines_per_host[host] = 1

    logger.info("Engines per host: \n")
    if not allow_engine_overlap:
        try:
            # check that we haven't over-provisioned GPUs
            for host, n_engines in engines_per_host.iteritems():
                logger.info("%s: %s", host, n_engines)
                assert n_engines <= WS_N_GPUS[host]
        except AssertionError:
            logger.critical("Host has more engines than GPUs. Halting.")
            raise
    while True:
        try:
            # NOTE: could also be accomplished with process environment variables
            # broadcast device assignments and job_fnc
            for engine_id, engine_device in enumerate(device_assignments):
                print("Pushing to engine {}: device: {}".format(engine_id, engine_device))
                client[engine_id].push({'device': engine_device,
                                        'job_fnc': job_fnc}, block=True)

            for engine_id, (host, assigned_device) in enumerate(zip(engine_hosts, device_assignments)):
                remote_device = client[engine_id].pull('device').get()
                logger.info("Engine %s: host = %s; device = %s, remote device = %s",
                            engine_id, host, assigned_device, remote_device)
            break
        except RemoteError as remote_err:
            logger.warn("Caught remote error: %s. Sleeping for 10s before retry", remote_err)
            time.sleep(10)

    logger.info("Dispatching jobs: %s", job_args)

    # dispatch jobs, passing each engine's bound `device` as the second argument
    async_result = client[:].map(job_fnc, job_args, [Reference('device')] * len(job_args))

    start_time = time.time()
    while not async_result.ready():
        time.sleep(status_interval)
        n_finished = async_result.progress
        n_jobs = len(job_args)
        wall_time = time.time() - start_time
        logger.info("%s seconds elapsed. %s of %s jobs finished",
                    wall_time, n_finished, n_jobs)
    logger.info("All jobs finished in %s seconds!", async_result.wall_time)

def setup_logging(log_name, log_path):
    """ Sets up module level logging
    """
    # define module level logger
    logger = logging.getLogger(log_name)
    logger.setLevel(logging.DEBUG)

    log_path = os.path.expanduser(log_path)

    # define file handler for module
    fh = logging.FileHandler(log_path)
    fh.setLevel(logging.DEBUG)

    # create formatter and add to handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)

    # add handler to logger
    logger.addHandler(fh)

    return logger
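
# A minimal usage sketch, assuming the `ssh_gpu_py2` cluster is already running
# (e.g. started with `ipcluster start --profile=ssh_gpu_py2`). The
# `square_on_gpu` job function and its argument list are hypothetical
# stand-ins for real GPU work.

def square_on_gpu(x, device='/gpu:0'):
    """ Trivial job; in real use `device` would be passed to e.g. tf.device() """
    return x ** 2  # return values are ignored by gpu_job_runner


if __name__ == '__main__':
    gpu_job_runner(square_on_gpu, range(16), ipp_profile='ssh_gpu_py2')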