- Run server: nvidia-docker run --rm --name trtserver -p 8000:8000 -p 8001:8001 -v `pwd`:/models nvcr.io/nvidia/tritonserver:20.03.1-py3 trtserver --model-store=/models --api-version=2
#!/usr/bin/env python
# Run server: nvidia-docker run --rm --name trtserver -p 8000:8000 -p 8001:8001 -v `pwd`:/models nvcr.io/nvidia/tritonserver:20.03.1-py3 trtserver --model-store=/models --api-version=2
# Run client: nvidia-docker run -it -v `pwd`:/data --rm --net=host triton:20.03.1
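#
# Example invocation of this client from inside the client container: a minimal
# sketch only, assuming the script is saved as /data/client.py and the model
# repository contains a model named "retinaface" (both names are hypothetical --
# adjust them to your setup). The server exposes HTTP on port 8000 and gRPC on
# port 8001 as mapped by the "Run server" command above.
#
#   # HTTP (default protocol):
#   python /data/client.py -m retinaface -u localhost:8000 -i HTTP -b 1 /data/1.jpg
#
#   # gRPC:
#   python /data/client.py -m retinaface -u localhost:8001 -i gRPC -b 1 /data/1.jpg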
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import numpy as np
import cv2
import sys
from functools import partial
import os

import tritongrpcclient
import tritongrpcclient.model_config_pb2 as mc
import tritonhttpclient
from tritonclientutils.utils import triton_to_np_dtype
from tritonclientutils.utils import InferenceServerException

if sys.version_info >= (3, 0):
    import queue
else:
    import Queue as queue

class UserData:

    def __init__(self):
        self._completed_requests = queue.Queue()


# Callback used for async_stream_infer() / async_infer(): each (result, error)
# pair is handed back to the caller through the queue so errors can be raised
# and handled outside the callback.
def completion_callback(user_data, result, error):
    user_data._completed_requests.put((result, error))


FLAGS = None

def parse_model_grpc(model_metadata, model_config):
    input_metadatas = model_metadata.inputs
    input_configs = model_config.input
    output_metadatas = model_metadata.outputs
    input_names = [input_metadata.name for input_metadata in input_metadatas]
    output_names = [output_metadata.name for output_metadata in output_metadatas]
    return (model_config.max_batch_size, input_names, output_names)


def parse_model_http(model_metadata, model_config):
    input_metadatas = model_metadata['inputs']
    input_config = model_config['input']
    output_metadatas = model_metadata['outputs']
    input_names = [input_metadata['name'] for input_metadata in input_metadatas]
    output_names = [output_metadata['name'] for output_metadata in output_metadatas]
    max_batch_size = 0
    if 'max_batch_size' in model_config:
        max_batch_size = model_config['max_batch_size']
    return (max_batch_size, input_names, output_names)

def preprocess(img_raw):
    # Convert to float32, subtract the per-channel mean (BGR order), and
    # reorder from HWC to CHW as expected by the model input.
    img = np.float32(img_raw)
    im_height, im_width, _ = img.shape
    # Scale factors to map normalized box coordinates back to pixel space.
    scale = [im_width, im_height, im_width, im_height]
    img -= (104, 117, 123)
    img = img.transpose(2, 0, 1)
    return scale, img

def postprocess(results, output_names, batch_size):
    """
    Post-process results to show classifications.
    """
    for output_name in output_names:
        output_array = results.as_numpy(output_name)
        print(output_name, "output_array: ", output_array)
        if len(output_array) != batch_size:
            raise Exception("expected {} results, got {}".format(
                batch_size, len(output_array)))
        # for results in output_array:
        #     for result in results:
        #         if output_array.dtype.type == np.bytes_:
        #             cls = "".join(chr(x) for x in result).split(':')
        #         else:
        #             cls = result.split(':')
        #         print("    {} = {}".format(cls[0], cls[1]))

def requestGenerator(batched_image_data, input_names, output_names, dtype, FLAGS):
    # Set the input data
    inputs = []
    if FLAGS.protocol.lower() == "grpc":
        for input_name in input_names:
            inputs.append(
                tritongrpcclient.InferInput(input_name, batched_image_data.shape,
                                            dtype))
        inputs[0].set_data_from_numpy(batched_image_data)
    else:
        for input_name in input_names:
            inputs.append(
                tritonhttpclient.InferInput(input_name, batched_image_data.shape,
                                            dtype))
        inputs[0].set_data_from_numpy(batched_image_data, binary_data=False)

    outputs = []
    if FLAGS.protocol.lower() == "grpc":
        for output_name in output_names:
            outputs.append(
                tritongrpcclient.InferRequestedOutput(output_name,
                                                      class_count=FLAGS.classes))
    else:
        for output_name in output_names:
            outputs.append(
                tritonhttpclient.InferRequestedOutput(output_name,
                                                      binary_data=False,
                                                      class_count=FLAGS.classes))

    yield inputs, outputs, FLAGS.model_name, FLAGS.model_version

def augments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')
    parser.add_argument('-a',
                        '--async',
                        dest="async_set",
                        action="store_true",
                        required=False,
                        default=False,
                        help='Use asynchronous inference API')
    parser.add_argument('--streaming',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Use streaming inference API. ' +
                        'The flag is only available with gRPC protocol.')
    parser.add_argument('-m',
                        '--model-name',
                        type=str,
                        required=True,
                        help='Name of model')
    parser.add_argument(
        '-x',
        '--model-version',
        type=str,
        required=False,
        default="",
        help='Version of model. Default is to use latest version.')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        required=False,
                        default=1,
                        help='Batch size. Default is 1.')
    parser.add_argument('-c',
                        '--classes',
                        type=int,
                        required=False,
                        default=1,
                        help='Number of class results to report. Default is 1.')
    parser.add_argument(
        '-s',
        '--scaling',
        type=str,
        choices=['NONE', 'INCEPTION', 'VGG'],
        required=False,
        default='NONE',
        help='Type of scaling to apply to image pixels. Default is NONE.')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8000',
                        help='Inference server URL. Default is localhost:8000.')
    parser.add_argument('-i',
                        '--protocol',
                        type=str,
                        required=False,
                        default='HTTP',
                        help='Protocol (HTTP/gRPC) used to communicate with ' +
                        'the inference service. Default is HTTP.')
    parser.add_argument('image_filename',
                        type=str,
                        nargs='?',
                        default='1.jpg',
                        help='Input image / Input folder.')
    return parser.parse_args()

def init_model(FLAGS):
    if FLAGS.streaming and FLAGS.protocol.lower() != "grpc":
        raise Exception("Streaming is only allowed with gRPC protocol")

    try:
        if FLAGS.protocol.lower() == "grpc":
            # Create gRPC client for communicating with the server
            triton_client = tritongrpcclient.InferenceServerClient(
                url=FLAGS.url, verbose=FLAGS.verbose)
        else:
            # Create HTTP client for communicating with the server
            triton_client = tritonhttpclient.InferenceServerClient(
                url=FLAGS.url, verbose=FLAGS.verbose)
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)

    # Make sure the model matches our requirements, and get some
    # properties of the model that we need for preprocessing
    try:
        model_metadata = triton_client.get_model_metadata(
            model_name=FLAGS.model_name, model_version=FLAGS.model_version)
    except InferenceServerException as e:
        print("failed to retrieve the metadata: " + str(e))
        sys.exit(1)

    try:
        model_config = triton_client.get_model_config(
            model_name=FLAGS.model_name, model_version=FLAGS.model_version)
    except InferenceServerException as e:
        print("failed to retrieve the config: " + str(e))
        sys.exit(1)

    if FLAGS.protocol.lower() == "grpc":
        max_batch_size, input_name, output_name = parse_model_grpc(
            model_metadata, model_config.config)
    else:
        max_batch_size, input_name, output_name = parse_model_http(
            model_metadata, model_config)

    return triton_client, max_batch_size, input_name, output_name

def inferencing(triton_client, batched_image_data, input_name, output_name,
                dtype, FLAGS, sent_count, responses, user_data, async_requests):
    # Issue one request per item yielded by requestGenerator; user_data and
    # async_requests are passed in from the caller so that asynchronous results
    # are collected on the same objects the caller later reads from.
    try:
        for inputs, outputs, model_name, model_version in requestGenerator(
                batched_image_data, input_name, output_name, dtype, FLAGS):
            sent_count += 1
            if FLAGS.streaming:
                triton_client.async_stream_infer(
                    FLAGS.model_name,
                    inputs,
                    request_id=str(sent_count),
                    model_version=FLAGS.model_version,
                    outputs=outputs)
            elif FLAGS.async_set:
                if FLAGS.protocol.lower() == "grpc":
                    triton_client.async_infer(
                        FLAGS.model_name,
                        inputs,
                        partial(completion_callback, user_data),
                        request_id=str(sent_count),
                        model_version=FLAGS.model_version,
                        outputs=outputs)
                else:
                    async_requests.append(
                        triton_client.async_infer(
                            FLAGS.model_name,
                            inputs,
                            request_id=str(sent_count),
                            model_version=FLAGS.model_version,
                            outputs=outputs))
            else:
                responses.append(
                    triton_client.infer(FLAGS.model_name,
                                        inputs,
                                        request_id=str(sent_count),
                                        model_version=FLAGS.model_version,
                                        outputs=outputs))
    except InferenceServerException as e:
        print("inference failed: " + str(e))
        if FLAGS.streaming:
            triton_client.stop_stream()
        sys.exit(1)

    return responses, sent_count

def processer(triton_client, FLAGS, image_data, max_batch_size, input_name, output_name):
    responses = []
    image_idx = 0
    last_request = False
    user_data = UserData()

    # Holds the handles to the ongoing HTTP async requests.
    async_requests = []

    sent_count = 0
    dtype = "FP32"

    if FLAGS.streaming:
        triton_client.start_stream(partial(completion_callback, user_data))

    while not last_request:
        input_filenames = []
        repeated_image_data = []

        for idx in range(FLAGS.batch_size):
            input_filenames.append(idx)
            repeated_image_data.append(image_data[image_idx])
            image_idx = (image_idx + 1) % len(image_data)
            if image_idx == 0:
                last_request = True

        if max_batch_size > 0:
            batched_image_data = np.stack(repeated_image_data, axis=0)
        else:
            batched_image_data = np.array(repeated_image_data)

        # Send request
        responses, sent_count = inferencing(triton_client, batched_image_data,
                                            input_name, output_name, dtype,
                                            FLAGS, sent_count, responses,
                                            user_data, async_requests)

    if FLAGS.streaming:
        triton_client.stop_stream()

    if FLAGS.protocol.lower() == "grpc":
        if FLAGS.streaming or FLAGS.async_set:
            processed_count = 0
            while processed_count < sent_count:
                (results, error) = user_data._completed_requests.get()
                processed_count += 1
                if error is not None:
                    print("inference failed: " + str(error))
                    sys.exit(1)
                responses.append(results)
    else:
        if FLAGS.async_set:
            # Collect results from the ongoing async requests
            # for HTTP Async requests.
            for async_request in async_requests:
                responses.append(async_request.get_result())

    return responses

if __name__ == '__main__':
    FLAGS = augments()
    triton_client, max_batch_size, input_name, output_name = init_model(FLAGS)

    # Preprocess the image into input data according to model requirements.
    img_raw = cv2.imread(FLAGS.image_filename)
    if img_raw is None:
        print("failed to read image: " + FLAGS.image_filename)
        sys.exit(1)
    scale, img = preprocess(img_raw)

    image_data = [img]

    responses = processer(triton_client, FLAGS, image_data, max_batch_size,
                          input_name, output_name)

    for response in responses:
        if FLAGS.protocol.lower() == "grpc":
            this_id = response.get_response().id
        else:
            this_id = response.get_response()["id"]
        print("Request {}, batch size {}".format(this_id, FLAGS.batch_size))
        postprocess(response, output_name, FLAGS.batch_size)

    print("PASS")