Skip to content

Instantly share code, notes, and snippets.

@zhangqiaorjc
Forked from mattjj/zhangqiaorjc_gpu_jax.py
Last active August 24, 2021 23:06
Show Gist options
  • Save zhangqiaorjc/4d574d0819acabf0609f8758b5b3fb5e to your computer and use it in GitHub Desktop.
Save zhangqiaorjc/4d574d0819acabf0609f8758b5b3fb5e to your computer and use it in GitHub Desktop.
# [email protected]
import atexit
import functools
from absl import app
from absl import flags
from absl import logging
import jax
from jax.lib import xla_extension as xc
# Command-line flags describing the cluster layout. Every host is given the
# same server address; host 0 additionally starts the service at that address.
flags.DEFINE_string(name='server_ip', default='', help='server ip addr')
flags.DEFINE_integer(name='server_port', default=0, help='server ip port')
flags.DEFINE_integer(name='num_hosts', default=1, help='num of hosts')
flags.DEFINE_integer(name='host_idx', default=0, help='index of current host')

FLAGS = flags.FLAGS
def connect_to_gpu_cluster():
    """Join this process to the multi-host GPU cluster described by FLAGS.

    Host 0 (``--host_idx=0``) first starts the distributed runtime service at
    ``<server_ip>:<server_port>`` for ``--num_hosts`` participants. Every host,
    host 0 included, then connects a distributed runtime client to that
    address and registers a 'gpu' backend factory built on that client.

    Returns:
        The service object started on host 0, or ``None`` on every other host.
    """
    endpoint = f'{FLAGS.server_ip}:{FLAGS.server_port}'
    host_index = FLAGS.host_idx

    coordinator = None
    if host_index == 0:
        logging.info('starting service on %s', endpoint)
        coordinator = xc.get_distributed_runtime_service(
            endpoint, FLAGS.num_hosts)
        # Shut the service down explicitly at exit: the Python interpreter may
        # not run the service destructor on process termination.
        atexit.register(coordinator.shutdown)

    logging.info('connecting to service on %s', endpoint)
    client = xc.get_distributed_runtime_client(endpoint, host_index)
    client.connect()
    # atexit handlers run in reverse registration order, so the client shuts
    # down before the service registered above.
    atexit.register(client.shutdown)

    # Register the distributed GPU backend under the 'gpu' name.
    make_client = functools.partial(
        jax.lib.xla_client.make_gpu_client, client, host_index)
    jax.lib.xla_bridge.register_backend_factory(
        'gpu', make_client, priority=300)
    return coordinator
def main(argv):
    """Connect to the GPU cluster and log the devices JAX reports.

    Args:
        argv: Positional command-line arguments passed in by ``app.run``;
            not used.
    """
    del argv  # Unused.

    # Hold on to the returned handle for the duration of main so the service
    # object (non-None only on host 0) stays referenced while devices are
    # queried.
    service = connect_to_gpu_cluster()
    logging.info('gpu cluster connected')

    all_devices = jax.devices()
    logging.info('devices %s', all_devices)
    host_devices = jax.local_devices()
    logging.info('local devices %s', host_devices)

    # Nothing is shut down here; the shutdown calls are the atexit handlers
    # registered in connect_to_gpu_cluster.
    logging.info('shutting down gpu cluster...')
# Script entry point: hand main to absl's app.run when executed directly
# (not when imported as a module).
if __name__ == '__main__':
    app.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment