import os
import time
import numpy as np
import onnxruntime as rt
# ———————————————————————————————————————————————
# FORCE SINGLE-CORE FOR BLAS (if used under the hood)
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
# ———————————————————————————————————————————————
# SESSION OPTIONS
sess_opts = rt.SessionOptions()
sess_opts.intra_op_num_threads = 1 # one thread per op
sess_opts.inter_op_num_threads = 1 # one thread across ops
sess_opts.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_opts.enable_profiling = True # turn on built-in profiler
# ———————————————————————————————————————————————
# LOAD MODEL
model_path = "concat_tree_ensemble.onnx"
sess = rt.InferenceSession(model_path, sess_opts)
# grab the first input name (assumes single-input model)
input_name = sess.get_inputs()[0].name
# ———————————————————————————————————————————————
# SYNTHETIC INPUT
batch_size = 200
feature_dim = 20
X = np.random.rand(batch_size, feature_dim).astype(np.float32)
# ———————————————————————————————————————————————
# WARM-UP (not timed; fills caches, JIT, etc.)
for _ in range(10):
    sess.run(None, {input_name: X})
# ———————————————————————————————————————————————
# BENCHMARK
num_iters = 1000
t0 = time.perf_counter()
for _ in range(num_iters):
    sess.run(None, {input_name: X})
t1 = time.perf_counter()
total_sec = t1 - t0
avg_ms = total_sec / num_iters * 1e3
throughput = (batch_size * num_iters) / total_sec
print(f"Ran {num_iters} inferences of batch {batch_size}")
print(f"→ Total time: {total_sec:.3f} s")
print(f"→ Avg latency per batch: {avg_ms:.3f} ms")
print(f"→ Throughput: {throughput:.1f} samples/sec")
# ———————————————————————————————————————————————
# DUMP PROFILE TRACE
profile_file = sess.end_profiling()
print("ONNX Runtime profile file:", profile_file)
import json
import tempfile
from typing import Any, Dict, List, Tuple
import numpy as np
import xgboost as xgb
import onnx
from onnx import helper, TensorProto, shape_inference
def train_example(
    num_rounds: int = 100, max_depth: int = 6, n_feat: int | None = None
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
    """
    Train a simple multi-target XGBoost model for demonstration.
    Returns the booster, input features X, and labels y.
    """
    if n_feat:
        beta = np.random.randn(n_feat, 4)
        X_train = np.random.randn(200, n_feat).astype(np.float32)
        y_train = X_train @ beta + np.random.randn(200, 4)
        X_test = np.random.randn(200, n_feat).astype(np.float32)
        y_test = X_test @ beta + np.random.randn(200, 4)
    else:
        X_train = np.random.randn(200, 3).astype(np.float32)
        y_train = X_train * np.array([[0.2, 1.0, -0.2]], dtype=np.float32) + \
            np.array([[2.0, -0.5, -2.0]], dtype=np.float32)
        y_train += 0.2 * np.random.randn(200, 3).astype(np.float32)
        X_test = np.random.randn(200, 3).astype(np.float32)
        y_test = X_test * np.array([[0.2, 1.0, -0.2]], dtype=np.float32) + \
            np.array([[2.0, 0.5, -2.0]], dtype=np.float32)
        y_test += 0.2 * np.random.randn(200, 3).astype(np.float32)
    dtrain = xgb.DMatrix(X_train, label=y_train)
    params = {
        "objective": "reg:squarederror",
        "tree_method": "hist",
        "multi_strategy": "one_output_per_tree",
        "num_target": y_train.shape[1],
        "max_depth": max_depth,
        "eta": 0.2,
    }
    booster = xgb.train(params, dtrain, num_boost_round=num_rounds)
    return booster, X_test, y_test
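# A minimal sanity-check sketch (hypothetical helper, not called anywhere in the gist):
# with multi_strategy="one_output_per_tree" the booster is assumed to store
# num_boost_round * n_targets trees, one per (round, target) pair. The converter below
# relies on that interleaved layout when it slices trees[i * n_targets + t].
def _sanity_check_tree_layout() -> None:
    bst, _, _ = train_example(num_rounds=5, n_feat=20)
    # The n_feat branch of train_example generates 4 targets, so expect 5 * 4 trees.
    assert len(bst.get_dump()) == 5 * 4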
def dump_to_dict(bst: xgb.Booster) -> Tuple[List[Dict[str, Any]], int, int, float]:
    """
    Save booster to a temporary JSON file and extract trees, number of targets,
    number of features, and base_score.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp:
        bst.save_model(tmp.name)
        tmp_path = tmp.name
    with open(tmp_path, "r") as f:
        data = json.load(f)
    model_info = data["learner"]["gradient_booster"]["model"]
    trees = model_info["trees"]
    params = data["learner"]["learner_model_param"]
    n_targets = int(params["num_target"])
    n_features = int(params.get("num_feature", 0))
    base_score = float(params["base_score"])
    return trees, n_targets, n_features, base_score
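# For reference, each entry of the trees list above is expected to be a dict of parallel
# per-node arrays ("left_children", "right_children", "split_indices", "split_conditions",
# "default_left", "base_weights") plus "tree_param" metadata such as "size_leaf_vector";
# Collector.collect below reads exactly these fields. This summary is an assumption about
# the XGBoost JSON model layout, not something the code itself verifies.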
class Collector:
    """
    Collector for preparing attributes for ai.onnx.ml TreeEnsembleRegressor.
    """

    def __init__(self):
        # Both branch & leaf in one list, matching the ONNX example IR
        self.nodes_treeids: List[int] = []
        self.nodes_nodeids: List[int] = []
        self.nodes_featureids: List[int] = []
        self.nodes_modes: List[str] = []
        self.nodes_values: List[float] = []
        self.nodes_missing_value_tracks_true: List[int] = []
        self.nodes_truenodeids: List[int] = []
        self.nodes_falsenodeids: List[int] = []
        # Then separate arrays for leaf contributions
        self.target_treeids: List[int] = []
        self.target_nodeids: List[int] = []
        self.target_ids: List[int] = []
        self.target_weights: List[float] = []
    def collect(
        self,
        tree: Dict[str, Any],
        size: int,
        print_mapping: bool = False,
        tree_id: int = 0,
    ) -> None:
        lefts = tree["left_children"]
        rights = tree["right_children"]
        defaults = tree.get("default_left", [0] * len(lefts))
        internal_nodes = [nid for nid in range(len(lefts)) if lefts[nid] >= 0 or rights[nid] >= 0]
        leaf_nodes = [nid for nid in range(len(lefts)) if lefts[nid] < 0 and rights[nid] < 0]
        # --- STUMP CASE: no splits ---
        if not internal_nodes:
            if print_mapping:
                print(f"Tree {tree_id}: stump only, creating dummy split.")
            stump_val = tree["split_conditions"][0]
            # Branch node id = 0
            self.nodes_treeids.append(tree_id)
            self.nodes_nodeids.append(0)
            self.nodes_featureids.append(0)
            self.nodes_modes.append("BRANCH_LT")
            self.nodes_values.append(stump_val)
            self.nodes_missing_value_tracks_true.append(defaults[0])
            # children → leaf IDs 1 & 2
            self.nodes_truenodeids.append(1)
            self.nodes_falsenodeids.append(2)
            # Leaf entries in nodes_*
            for leaf_id in (1, 2):
                self.nodes_treeids.append(tree_id)
                self.nodes_nodeids.append(leaf_id)
                self.nodes_featureids.append(0)
                self.nodes_modes.append("LEAF")
                self.nodes_values.append(0.0)
                self.nodes_missing_value_tracks_true.append(0)
                self.nodes_truenodeids.append(0)
                self.nodes_falsenodeids.append(0)
            # And the target weights
            for leaf_id in (1, 2):
                self.target_treeids.append(tree_id)
                self.target_nodeids.append(leaf_id)
                self.target_ids.append(0)
                self.target_weights.append(float(stump_val))
            return
        # --- NORMAL TREE CASE ---
        # Map local node IDs
        local_internal_map = {nid: idx for idx, nid in enumerate(internal_nodes)}
        local_leaf_map = {nid: idx + len(internal_nodes) for idx, nid in enumerate(leaf_nodes)}
        if print_mapping:
            print(f"Tree {tree_id}: internals={internal_nodes}, leaves={leaf_nodes}")
        # Emit branch nodes
        for nid in internal_nodes:
            lid, rid = lefts[nid], rights[nid]
            local_id = local_internal_map[nid]
            self.nodes_treeids.append(tree_id)
            self.nodes_nodeids.append(local_id)
            self.nodes_featureids.append(tree["split_indices"][nid])
            self.nodes_modes.append("BRANCH_LT")
            self.nodes_values.append(tree["split_conditions"][nid])
            self.nodes_missing_value_tracks_true.append(defaults[nid])
            # True child
            self.nodes_truenodeids.append(
                local_internal_map[lid] if lid in local_internal_map else local_leaf_map[lid]
            )
            # False child
            self.nodes_falsenodeids.append(
                local_internal_map[rid] if rid in local_internal_map else local_leaf_map[rid]
            )
        # Emit leaf entries (for the ONNX IR style)
        for nid in leaf_nodes:
            lid = local_leaf_map[nid]
            self.nodes_treeids.append(tree_id)
            self.nodes_nodeids.append(lid)
            self.nodes_featureids.append(0)
            self.nodes_modes.append("LEAF")
            self.nodes_values.append(0.0)
            self.nodes_missing_value_tracks_true.append(defaults[nid])
            self.nodes_truenodeids.append(0)
            self.nodes_falsenodeids.append(0)
        # Finally, record the leaf weights via the target_* arrays
        for nid in leaf_nodes:
            leaf_id = local_leaf_map[nid]
            vals = tree["base_weights"][nid * size:(nid + 1) * size]
            for tgt_idx, wt in enumerate(vals):
                self.target_treeids.append(tree_id)
                self.target_nodeids.append(leaf_id)
                self.target_ids.append(tgt_idx)
                self.target_weights.append(float(wt))
def convert_multilabel_xgb_booster_to_onnx(
    booster: xgb.Booster,
    n_features: int,
    print_tree_maps: bool = False,
) -> onnx.ModelProto:
    trees, n_targets, _, base_score = dump_to_dict(booster)
    # With multi_strategy="one_output_per_tree" the dumped trees are grouped round by
    # round, one tree per target within each round; de-interleave them per target here.
    rounds = len(trees) // n_targets
    per_target = [[trees[i * n_targets + t] for i in range(rounds)] for t in range(n_targets)]
    nodes: List[onnx.NodeProto] = []
    outs: List[str] = []
    for t_idx, t_trees in enumerate(per_target):
        col = Collector()
        for ridx, tree in enumerate(t_trees):
            size = int(tree["tree_param"]["size_leaf_vector"])
            col.collect(tree, size, print_mapping=print_tree_maps, tree_id=ridx)
        name = f"Y_{t_idx}"
        node = helper.make_node(
            "TreeEnsembleRegressor",
            inputs=["X"], outputs=[name], domain="ai.onnx.ml",
            aggregate_function="SUM", base_values=[base_score],
            n_targets=1, post_transform="NONE",
            nodes_treeids=col.nodes_treeids,
            nodes_nodeids=col.nodes_nodeids,
            nodes_featureids=col.nodes_featureids,
            nodes_modes=col.nodes_modes,
            nodes_values=col.nodes_values,
            nodes_missing_value_tracks_true=col.nodes_missing_value_tracks_true,
            nodes_truenodeids=col.nodes_truenodeids,
            nodes_falsenodeids=col.nodes_falsenodeids,
            target_treeids=col.target_treeids,
            target_nodeids=col.target_nodeids,
            target_ids=col.target_ids,
            target_weights=col.target_weights,
        )
        nodes.append(node)
        outs.append(name)
    # Alternative: concatenate outputs and add back base_score with an explicit Add node
    # nodes.append(helper.make_node("Concat", inputs=outs, outputs=["Y_concat"], axis=1))
    # base = helper.make_tensor("base_score", TensorProto.FLOAT, [n_targets], [base_score] * n_targets)
    # nodes.append(helper.make_node("Add", inputs=["Y_concat", "base_score"], outputs=["Y"]))
    # Concatenate the per-target outputs (base_score is already folded in via base_values)
    nodes.append(helper.make_node("Concat", inputs=outs, outputs=["Y"], axis=1))
    X_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_features])
    Y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_targets])
    # graph = helper.make_graph(nodes, "xgb_multi_treeensemble_regressor", [X_info], [Y_info], initializer=[base])
    graph = helper.make_graph(nodes, "xgb_multi_treeensemble_regressor", [X_info], [Y_info])
    model = helper.make_model(
        graph,
        producer_name="xgb_to_onnx_multi_treeensemble_regressor",
        opset_imports=[helper.make_opsetid("", 13), helper.make_opsetid("ai.onnx.ml", 3)],
    )
    return shape_inference.infer_shapes(model)
def main(
    print_xgb_json: bool = False,
    print_onnx_ir: bool = False,
    print_diff_details: bool = False,
    print_tree_maps: bool = False,
) -> None:
    import onnxruntime as ort

    booster, X, y = train_example(n_feat=20)
    if print_xgb_json:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp:
            booster.save_model(tmp.name)
        raw = json.loads(open(tmp.name).read())
        print(json.dumps(raw["learner"]["gradient_booster"]["model"]["trees"][2::3], indent=2))
    model = convert_multilabel_xgb_booster_to_onnx(
        booster, n_features=X.shape[1], print_tree_maps=print_tree_maps
    )
    onnx.save_model(model, "concat_tree_ensemble.onnx")
    if print_onnx_ir:
        # print the third ensemble node in the ONNX graph
        print(model.graph.node[2])
    sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
    pred_xgb = booster.predict(xgb.DMatrix(X))
    pred_onnx = sess.run(None, {"X": X.astype(np.float32)})[0]
    print("XGBoost pred[0:3]:", pred_xgb[:3])
    print("ONNX pred[0:3]:", pred_onnx[:3])
    print("Diff[0:3]:", np.abs(pred_xgb - pred_onnx)[:3])
    print("Max diff:", np.abs(pred_xgb - pred_onnx).max())
    if print_diff_details:
        arg = np.argsort(np.abs(pred_xgb - pred_onnx).max(axis=1))
        print("Argmax X:", X[arg[-5:]])
        print("Argmax pred_xgb:", pred_xgb[arg[-5:]])
        print("Argmax pred_onnx:", pred_onnx[arg[-5:]])
        print("Argmax diff:", pred_xgb[arg[-5:]] - pred_onnx[arg[-5:]])
        # If still off on target 2, dump debug info.
        # NOTE: `raw` is only defined when print_xgb_json=True, so this dump assumes
        # both flags were enabled.
        if np.abs(pred_xgb - pred_onnx).max(axis=0)[2] > 0.01:
            trees, n_targets, _, base_score = dump_to_dict(booster)
            with open("debug_log.txt", "w") as f:
                f.write(
                    json.dumps(raw["learner"]["gradient_booster"]["model"]["trees"][2::3], indent=2)
                    + "\n\n"
                    + str(model.graph.node[2])
                    + "\n\n"
                    + f"base_score={base_score}\n"
                    + "Argmax X: " + str(X[arg[-5:]]) + "\n"
                    + "Argmax pred_xgb: " + str(pred_xgb[arg[-5:]]) + "\n"
                    + "Argmax pred_onnx:" + str(pred_onnx[arg[-5:]]) + "\n"
                    + "Argmax diff:" + str(pred_xgb[arg[-5:]] - pred_onnx[arg[-5:]])
                )
            breakpoint()
if __name__ == "__main__":
    import tyro

    tyro.cli(main)
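# Example invocation (a sketch; tyro generates the CLI flags from main()'s keyword
# arguments, so the exact flag spellings may depend on the installed tyro version):
#   python <this_script>.py --print-diff-details
# Running it writes concat_tree_ensemble.onnx, the model benchmarked by the onnxruntime
# timing script at the top of this gist.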