Skip to content

Instantly share code, notes, and snippets.

@cheeyeo
Last active May 1, 2025 15:47
Show Gist options
  • Save cheeyeo/93cab948789a5285064f146c781bc5a0 to your computer and use it in GitHub Desktop.
Save cheeyeo/93cab948789a5285064f146c781bc5a0 to your computer and use it in GitHub Desktop.
Python script to upload a Triton Inference Server model to Amazon SageMaker and deploy it to a real-time endpoint
#!/usr/bin/env python
import os
import json
import time
import boto3
import sagemaker
# Account that hosts the SageMaker Triton Inference Server image in ECR, by region.
TRITON_ACCOUNT_ID_MAP = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}

# Triton server image tag deployed by default.
TRITON_IMAGE_TAG = "23.12-py3"


def triton_image_uri(region, tag=TRITON_IMAGE_TAG):
    """Return the ECR URI of the SageMaker Triton server image for *region*.

    Args:
        region: AWS region name, e.g. "eu-west-2". May be None when the
            boto3 session has no region configured.
        tag: Image tag to use; defaults to TRITON_IMAGE_TAG.

    Returns:
        The fully qualified image URI string.

    Raises:
        ValueError: if *region* has no entry in TRITON_ACCOUNT_ID_MAP
            (this includes None).
    """
    if region not in TRITON_ACCOUNT_ID_MAP:
        raise ValueError(f"Unsupported region: {region!r}")
    # China regions use a separate AWS partition with its own domain suffix.
    base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    account_id = TRITON_ACCOUNT_ID_MAP[region]
    return f"{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:{tag}"


def wait_for_endpoint(sm_client, endpoint_name, poll_seconds=60):
    """Poll an endpoint until it is no longer "Creating".

    Prints the status after every DescribeEndpoint call.

    Args:
        sm_client: A boto3 SageMaker client.
        endpoint_name: Name of the endpoint to poll.
        poll_seconds: Delay between polls.

    Returns:
        The last ``describe_endpoint`` response dict.
    """
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    while status == "Creating":
        time.sleep(poll_seconds)
        resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
        print("Status: " + status)
    return resp


def main():
    """Upload ensemble_model.tar.gz and deploy it to a new SageMaker endpoint.

    Creates a model, an endpoint config (one ml.g5.2xlarge instance) and an
    endpoint, all named with a UTC timestamp. Then waits for the endpoint to
    leave the "Creating" state.

    Raises:
        ValueError: if the session region has no Triton image mapping.
        SystemExit: if the endpoint finishes in a status other than
            "InService" (non-zero exit with the failure reason).
    """
    # Let the role come from the environment instead of editing the script;
    # the original placeholder remains the fallback.
    role = os.environ.get("SAGEMAKER_ROLE_ARN", "SAGEMAKER EXECUTION ROLE ARN")

    # One boto3 session shared by all clients so they agree on region and
    # credentials.
    boto_session = boto3.Session()
    # Resolve the image first so an unsupported region fails before upload.
    image_uri = triton_image_uri(boto_session.region_name)

    sm_client = boto_session.client("sagemaker")
    sagemaker_session = sagemaker.Session(boto_session=boto_session)

    # Uploads to the session's default bucket, i.e. sagemaker-<REGION>-<ACCOUNT ID>.
    model_data_uri = sagemaker_session.upload_data(
        path="ensemble_model.tar.gz",
        key_prefix="ensemble_model",
    )

    container = {
        "Image": image_uri,
        "ModelDataUrl": model_data_uri,
        "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "ensemble_model"},
    }

    # A UTC timestamp keeps resource names unique across runs.
    ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

    sm_model_name = f"ensemble-{ts}"
    create_model_response = sm_client.create_model(
        ModelName=sm_model_name,
        ExecutionRoleArn=role,
        PrimaryContainer=container,
    )
    print("Model Arn: " + create_model_response["ModelArn"])

    endpoint_config_name = f"ensemble-epc-{ts}-2xl"
    create_endpoint_config_response = sm_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=[
            {
                "InstanceType": "ml.g5.2xlarge",
                "InitialVariantWeight": 1,
                "InitialInstanceCount": 1,
                "ModelName": sm_model_name,
                "VariantName": "AllTraffic",
            }
        ],
    )
    print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

    endpoint_name = f"ensemble-ep-{ts}-2xl"
    create_endpoint_response = sm_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
    )
    print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

    resp = wait_for_endpoint(sm_client, endpoint_name)
    status = resp["EndpointStatus"]
    print("Arn: " + resp["EndpointArn"])
    print("Status: " + status)
    if status != "InService":
        # DescribeEndpoint includes FailureReason when creation fails.
        reason = resp.get("FailureReason", "no FailureReason reported")
        raise SystemExit(
            f"Endpoint {endpoint_name} finished in status {status}: {reason}"
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment