"""Manual smoke test for the micro-SAM ViT-B backbone and decoder models.

The models are reached through the ``triton-client`` service of the server at
``SERVER_URL``. Each test fetches the model config, runs one inference on
synthetic input, and prints the result together with the inference time.
"""
import time

import numpy as np
from imjoy_rpc.hypha import connect_to_server

SERVER_URL = 'http://127.0.0.1:9520'  # "https://ai.imjoy.io"

BACKBONE_MODEL = "micro-sam-vit-b-backbone"
DECODER_MODEL = "micro-sam-vit-b-decoder"


async def test_backbone(triton):
    """Run the backbone once on a random 3-channel 1024x1024 image.

    Args:
        triton: The ``triton-client`` service proxy obtained from
            ``server.get_service``.

    Returns:
        The ``output0__0`` entry of the execution result (the embedding).
    """
    config = await triton.get_config(model_name=BACKBONE_MODEL)
    print(config)
    # randint's upper bound is exclusive: 256 is required so that the full
    # uint8 range 0..255 can occur. The model is then fed float32 values.
    image = np.random.randint(
        0, 256, size=(1, 3, 1024, 1024), dtype=np.uint8
    ).astype("float32")

    # perf_counter is monotonic and high-resolution, unlike time.time().
    # Only the inference call is timed; printing the result is excluded.
    start_time = time.perf_counter()
    result = await triton.execute(
        inputs=[image],
        model_name=BACKBONE_MODEL,
    )
    elapsed = time.perf_counter() - start_time

    print("Backbone", result)
    embedding = result['output0__0']
    print("Time taken: ", elapsed)
    print("Test passed", embedding.shape)
    return embedding


async def test_decoder(triton, image_embeddings=None):
    """Run the prompt decoder once on synthetic prompt inputs.

    Args:
        triton: The ``triton-client`` service proxy obtained from
            ``server.get_service``.
        image_embeddings: Optional embedding to decode, for example the value
            returned by ``test_backbone``. When omitted, a random float32
            array of shape (1, 256, 64, 64) is used. That shape matches the
            dims the decoder config reports for ``image_embeddings``.

    Returns:
        The full execution result dict.
    """
    config = await triton.get_config(model_name=DECODER_MODEL)
    print("Decoder", config)
    # Input specs reported by the decoder config:
    # {'name': 'orig_im_size', 'dims': [2]}
    # {'name': 'has_mask_input', 'dims': [1]}
    # {'name': 'mask_input', 'dims': [1, 1, 256, 256]}
    # {'name': 'point_labels', 'dims': [1, -1]}
    # {'name': 'point_coords', 'dims': [1, -1, 2]}
    # {'name': 'image_embeddings', 'dims': [1, 256, 64, 64]}
    orig_im_size = np.array([1024, 1024], dtype=np.float32)
    has_mask_input = np.array([0], dtype=np.float32)
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    # NOTE(review): confirm that labels 0/1/2 are meaningful values for a
    # 3-point prompt in the exported decoder.
    point_labels = np.array([[0, 1, 2]], dtype=np.float32)
    point_coords = np.array(
        [[[100, 200], [300, 400], [500, 600]]], dtype=np.float32
    )
    if image_embeddings is None:
        image_embeddings = np.random.rand(1, 256, 64, 64).astype(np.float32)
    else:
        image_embeddings = np.asarray(image_embeddings, dtype=np.float32)

    # Time only the inference call, as test_backbone does. The get_config
    # round trip is not counted.
    start_time = time.perf_counter()
    result = await triton.execute(
        # Passed positionally. Assumed to follow the input order listed
        # above; confirm against the model config.
        inputs=[
            orig_im_size,
            has_mask_input,
            mask_input,
            point_labels,
            point_coords,
            image_embeddings,
        ],
        model_name=DECODER_MODEL,
    )
    elapsed = time.perf_counter() - start_time

    # the output keys are ['iou_predictions', 'low_res_masks', 'masks', '__info__']
    print(result)
    print("Time taken: ", elapsed)
    print("Test passed", result['masks'].shape)
    return result


async def run():
    """Connect to ``SERVER_URL`` and run the backbone and decoder smoke tests."""
    server = await connect_to_server(
        {"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
    )
    try:
        triton = await server.get_service("triton-client")
        await test_backbone(triton)
        await test_decoder(triton)
    finally:
        # Close the connection even if a test raises. The lookup is guarded in
        # case the server proxy does not expose ``disconnect``.
        disconnect = getattr(server, "disconnect", None)
        if disconnect is not None:
            await disconnect()
if __name__ == "__main__":
    # Script entry point: drive the async smoke tests on a fresh event loop.
    import asyncio

    asyncio.run(run())