Last active
May 11, 2022 19:05
-
-
Save wmayner/be88c0cb6f8fb3708e8126774e86ce95 to your computer and use it in GitHub Desktop.
Map a Ray remote function, returning early if a particular value is found
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import time | |
import ray | |
# Initialize Ray at module import time; the remote-task helpers below and the
# test section at the end of the file submit and wait on Ray tasks.
ray.init()
def as_completed(object_refs, num_returns=1):
    """Yield remote results in order of completion.

    Repeatedly waits on the refs that have not finished yet and yields the
    values of each finished batch as soon as that batch is ready.

    Args:
        object_refs: Iterable of Ray object refs to wait on (a list, tuple,
            generator, ... -- it is materialized into a list first).
        num_returns: Maximum number of refs collected per ``ray.wait`` call;
            results are fetched and yielded in batches of at most this size.

    Yields:
        The value of each task, in the order the tasks finished.
    """
    # ray.wait only accepts a list, so materialize once up front; this also
    # makes the ``while`` test a real emptiness check for generator inputs.
    unfinished = list(object_refs)
    while unfinished:
        # ray.wait rejects num_returns > len(refs), which would otherwise happen
        # on the final batch whenever len(object_refs) % num_returns != 0.
        batch_size = min(num_returns, len(unfinished))
        finished, unfinished = ray.wait(unfinished, num_returns=batch_size)
        yield from ray.get(finished)
def cancel_all(object_refs, *args, **kwargs):
    """Cancel all remote tasks.

    Calls ``ray.cancel(ref, *args, **kwargs)`` once for every ref, so any
    extra positional/keyword arguments are forwarded unchanged to each call.

    Args:
        object_refs: Iterable of Ray object refs whose tasks should be canceled.
        *args: Extra positional arguments passed through to ``ray.cancel``.
        **kwargs: Extra keyword arguments passed through to ``ray.cancel``.

    Returns:
        ``object_refs``, the same object that was passed in.
    """
    # NOTE: deliberately not decorated with functools.wraps(ray.cancel): this is
    # not a decorator-style wrapper, and wraps would overwrite this function's
    # __name__/__doc__ with ray.cancel's and set __wrapped__, making
    # inspect.signature report ray.cancel's single-ref signature instead.
    for ref in object_refs:
        ray.cancel(ref, *args, **kwargs)
    return object_refs
def shortcircuit(
    items,
    shortcircuit_value=None,
    shortcircuit_callback=None,
    shortcircuit_callback_args=None,
):
    """Yield from an iterable, stopping early if a certain value is found.

    Items are passed through unchanged up to and including the first one that
    compares equal (``==``) to ``shortcircuit_value``; iteration then stops.

    Args:
        items: Iterable to pass through.
        shortcircuit_value: Value that ends iteration once it has been yielded.
        shortcircuit_callback: Optional callable invoked once after the matching
            value has been yielded. It runs when the generator is resumed after
            the match *or* when it is closed instead (e.g. the consumer
            ``break``s as soon as it sees the match).
        shortcircuit_callback_args: Passed to the callback as its single
            positional argument.

    Yields:
        Items of ``items`` up to and including the first match.
    """
    for result in items:
        if result == shortcircuit_value:
            try:
                yield result
            finally:
                # try/finally so the callback still fires when the consumer stops
                # iterating right after receiving the match (generator.close()
                # raises GeneratorExit at this yield); otherwise, e.g., remaining
                # remote tasks would never be canceled.
                if shortcircuit_callback is not None:
                    shortcircuit_callback(shortcircuit_callback_args)
            return
        yield result
def map_remote_shortcircuit(
    func,
    *arg_lists,
    shortcircuit_value=None,
    shortcircuit_callback=cancel_all,
    shortcircuit_callback_args=None,
    **kwargs,
):
    """
    Map a remote function over argument lists, stopping once a given value is seen.

    One task ``func.remote(*call_args, **kwargs)`` is submitted immediately for
    each tuple produced by ``zip(*arg_lists)``. The returned generator yields
    results in completion order and stops after the first result equal to
    ``shortcircuit_value``, at which point ``shortcircuit_callback`` is invoked.
    Unless ``shortcircuit_callback_args`` is given, the callback receives the
    list of submitted object refs, so by default the remaining tasks are canceled.
    """
    object_refs = []
    for call_args in zip(*arg_lists):
        object_refs.append(func.remote(*call_args, **kwargs))

    # Fall back to handing the callback every submitted ref.
    callback_args = (
        object_refs if shortcircuit_callback_args is None else shortcircuit_callback_args
    )

    completed = as_completed(object_refs)
    return shortcircuit(
        completed,
        shortcircuit_value=shortcircuit_value,
        shortcircuit_callback=shortcircuit_callback,
        shortcircuit_callback_args=callback_args,
    )
# Test: each task sleeps in proportion to its argument, so small values finish
# first; the short-circuit value 0 is given the same delay as x == 10. The
# printed list should presumably end with 0 shortly after the small values
# (the exact order around 10 depends on scheduling).
@ray.remote
def f(x):
    delay = x or 10
    time.sleep(delay / 10)
    return x


shortcircuit_value = 0
args = [*range(1, 100), shortcircuit_value]
print(list(map_remote_shortcircuit(f, args, shortcircuit_value=shortcircuit_value)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment