from imjoy_rpc.hypha import connect_to_server
import numpy as np
import time

# URL of the Hypha server that run() connects to. The trailing comment keeps an
# alternative server URL at hand for quick switching.
SERVER_URL = 'http://127.0.0.1:9520' # "https://ai.imjoy.io"


async def test_backbone(triton):
    """Smoke-test the ``micro-sam-vit-b-backbone`` model on the Triton service.

    Fetches and prints the model config, sends one random (1, 3, 1024, 1024)
    float32 image through ``triton.execute`` and prints the raw result, the
    time spent in the inference call, and the shape of the ``output0__0``
    entry of the result.

    Args:
        triton: the "triton-client" service proxy (as obtained in ``run``).
    """
    config = await triton.get_config(model_name="micro-sam-vit-b-backbone")
    print(config)

    # np.random.randint's upper bound is exclusive, so 256 is required for the
    # generated pixels to span the full uint8 range 0..255.
    image = np.random.randint(0, 256, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
        "float32"
    )

    # perf_counter is monotonic and high-resolution. Stop the clock right after
    # the RPC returns so printing the (large) result is not counted as
    # inference time.
    start_time = time.perf_counter()
    result = await triton.execute(
        inputs=[image],
        model_name="micro-sam-vit-b-backbone",
    )
    elapsed = time.perf_counter() - start_time

    print("Backbone", result)
    embedding = result['output0__0']
    print("Time taken: ", elapsed)
    print("Test passed", embedding.shape)


async def test_decoder(triton):
    """Smoke-test the ``micro-sam-vit-b-decoder`` model on the Triton service.

    Fetches and prints the model config, then runs the decoder once on
    synthetic inputs (three prompt points, an all-zero mask with
    ``has_mask_input`` = 0, and a random image embedding). Prints the raw
    result, the time spent in the inference call, and the shape of the
    ``masks`` output.

    Args:
        triton: the "triton-client" service proxy (as obtained in ``run``).
    """
    config = await triton.get_config(model_name="micro-sam-vit-b-decoder")
    print("Decoder", config)
    # Decoder inputs, in the same order they are passed to execute() below:
    # {'name': 'orig_im_size', 'dims': [2]}
    # {'name': 'has_mask_input', 'dims': [1]}
    # {'name': 'mask_input', 'dims': [1, 1, 256, 256]}
    # {'name': 'point_labels', 'dims': [1, -1]}
    # {'name': 'point_coords', 'dims': [1, -1, 2]}
    # {'name': 'image_embeddings', 'dims': [1, 256, 64, 64]}
    orig_im_size = np.array([1024, 1024], dtype=np.float32)
    has_mask_input = np.array([0], dtype=np.float32)
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    point_labels = np.array([[0, 1, 2]], dtype=np.float32)
    point_coords = np.array([[[100, 200], [300, 400], [500, 600]]], dtype=np.float32)
    image_embeddings = np.random.rand(1, 256, 64, 64).astype(np.float32)

    # Time only the inference RPC. The config fetch and input construction
    # happen before the clock starts, and printing happens after it stops.
    start_time = time.perf_counter()
    result = await triton.execute(
        inputs=[orig_im_size, has_mask_input, mask_input, point_labels, point_coords, image_embeddings],
        model_name="micro-sam-vit-b-decoder",
    )
    elapsed = time.perf_counter() - start_time

    # the output keys are ['iou_predictions', 'low_res_masks', 'masks', '__info__']
    print(result)
    print("Time taken: ", elapsed)
    print("Test passed", result['masks'].shape)
    

async def run():
    """Connect to the Hypha server, run the backbone and decoder tests, then disconnect.

    The connection is closed in a ``finally`` block, so it is released even
    when one of the tests raises.
    """
    server = await connect_to_server(
        {"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
    )
    try:
        triton = await server.get_service("triton-client")
        await test_backbone(triton)
        await test_decoder(triton)
    finally:
        # Without this, the RPC connection is left open when the event loop
        # shuts down, and its background tasks are cancelled uncleanly.
        await server.disconnect()

if __name__ == "__main__":
    # Script entry point: run the async test sequence to completion on a fresh
    # event loop.
    from asyncio import run as run_event_loop

    run_event_loop(run())