-
-
Save zhangqiaorjc/4d574d0819acabf0609f8758b5b3fb5e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# [email protected] | |
import atexit | |
import functools | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import jax | |
from jax.lib import xla_extension as xc | |
# Command-line flags: the address of the distributed runtime service and this
# process's position among the participating hosts.
FLAGS = flags.FLAGS

flags.DEFINE_string(name='server_ip', default='', help='server ip addr')
flags.DEFINE_integer(name='server_port', default=0, help='server ip port')
flags.DEFINE_integer(name='num_hosts', default=1, help='num of hosts')
flags.DEFINE_integer(name='host_idx', default=0, help='index of current host')
def connect_to_gpu_cluster():
    """Connect this process to the multi-host GPU cluster named by the flags.

    The host with ``host_idx == 0`` first starts the distributed runtime
    service at ``server_ip:server_port``. Every host (including host 0) then
    creates a distributed runtime client for that same address, connects it,
    and registers a 'gpu' backend factory built around the client.

    Returns:
        The distributed runtime service object on host 0, ``None`` on every
        other host.
    """
    host_index = FLAGS.host_idx
    # The service (host 0) and every client use the same endpoint.
    endpoint = f'{FLAGS.server_ip}:{FLAGS.server_port}'

    runtime_service = None
    if host_index == 0:
        logging.info('starting service on %s', endpoint)
        runtime_service = xc.get_distributed_runtime_service(
            endpoint, FLAGS.num_hosts)
        # Shut the service down explicitly at interpreter exit: the Python
        # interpreter may not run the service destructor on process
        # termination.
        atexit.register(runtime_service.shutdown)

    logging.info('connecting to service on %s', endpoint)
    client = xc.get_distributed_runtime_client(endpoint, host_index)
    client.connect()
    # Registered after the service hook; atexit runs handlers in reverse
    # registration order, so the client shuts down before the service.
    atexit.register(client.shutdown)

    # Register the distributed GPU backend.
    gpu_backend_factory = functools.partial(
        jax.lib.xla_client.make_gpu_client, client, host_index)
    jax.lib.xla_bridge.register_backend_factory(
        'gpu', gpu_backend_factory, priority=300)
    return runtime_service
def main(argv):
    """Connect to the GPU cluster and log the devices JAX reports.

    Args:
        argv: Positional command-line arguments passed in by ``app.run``;
            not used.
    """
    del argv  # Unused.

    # NOTE(review): the returned service (non-None only on host 0) is
    # presumably held here so it stays referenced while main runs -- confirm.
    cluster_service = connect_to_gpu_cluster()
    logging.info('gpu cluster connected')

    all_devices = jax.devices()
    logging.info('devices %s', all_devices)
    host_devices = jax.local_devices()
    logging.info('local devices %s', host_devices)

    # No explicit teardown happens here; the shutdown callbacks were
    # registered with atexit inside connect_to_gpu_cluster().
    logging.info('shutting down gpu cluster...')
# Script entry point: hand control to absl, which invokes main().
if __name__ == '__main__':
    app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment