@yunho-c
Created April 17, 2025 00:34
Compare (diff) two HuggingFace Transformer Models (specifically: SpatialLM vs. base)
"""
Base Models (w/ sizes):
meta-llama/Llama-3.2-1B
Qwen/Qwen2.5-0.5B
SpatialLM Models:
manycore-research/SpatialLM-Llama-1B
manycore-research/SpatialLM-Qwen-0.5B
Commands:
python compare_models.py --base-model Qwen/Qwen2.5-0.5B --modified-model manycore-research/SpatialLM-Qwen-0.5B
python compare_models.py --base-model meta-llama/Llama-3.2-1B --modified-model manycore-research/SpatialLM-Llama-1B
"""
"""
compare_models.py
Usage:
python compare_models.py \
--base-model <huggingface_id> \
--modified-model <huggingface_id> \
[--output csv] \
[--device cpu]
This script:
1) Loads both models (via HuggingFace Transformers).
2) Extracts their state_dicts.
3) Reports:
- Keys only in base or only in SpatialLM.
- For shared keys: shape mismatches.
- For matching shapes: per-key L2 & max‐abs difference.
4) Optionally writes a CSV summary.
"""
import argparse
import torch
import numpy as np
from transformers import AutoModel, AutoConfig, AutoModelForCausalLM
from spatiallm import SpatialLMLlamaForCausalLM, SpatialLMQwenForCausalLM


def load_model(path_or_name, device):
    # trust_remote_code lets Transformers load SpatialLM's custom model classes.
    # Earlier loading attempts, kept for reference:
    # config = AutoConfig.from_pretrained(path_or_name, trust_remote_code=True)
    # model = AutoModel.from_pretrained(
    #     path_or_name, config=config, trust_remote_code=True
    # )
    # model = AutoModel.from_pretrained(path_or_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(path_or_name, trust_remote_code=True)
    return model.to(device)


def compare_state_dicts(sd_base, sd_target):
    keys_base = set(sd_base.keys())
    keys_target = set(sd_target.keys())
    only_base = sorted(keys_base - keys_target)
    only_target = sorted(keys_target - keys_base)
    shared = sorted(keys_base & keys_target)

    diffs = []
    for k in shared:
        t1 = sd_base[k]
        t2 = sd_target[k]
        if t1.shape != t2.shape:
            diffs.append(
                {
                    "key": k,
                    "status": "shape_mismatch",
                    "shape_base": t1.shape,
                    "shape_target": t2.shape,
                    "l2": None,
                    "max_abs": None,
                }
            )
        else:
            # Cast to float32 before differencing so half-precision checkpoints compare cleanly.
            delta = (t2.float() - t1.float()).cpu().numpy()
            l2 = np.linalg.norm(delta.ravel())
            max_abs = float(np.max(np.abs(delta)))
            diffs.append(
                {
                    "key": k,
                    "status": "ok",
                    "shape_base": t1.shape,
                    "shape_target": t2.shape,
                    "l2": l2,
                    "max_abs": max_abs,
                }
            )
    return only_base, only_target, diffs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-model",
        required=True,
        help="HuggingFace name or local dir of the base model",
    )
    parser.add_argument(
        "--modified-model",
        required=True,
        help="HuggingFace name or local dir of the SpatialLM variant",
    )
    parser.add_argument(
        "--device", default="cpu", help="torch device to load models on (cpu or cuda)"
    )
    parser.add_argument(
        "--output",
        choices=["csv", "none"],
        default="none",
        help="if csv, dump detailed diff to diffs.csv",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Loading base model '{args.base_model}' on {device}...")
    model_base = load_model(args.base_model, device)
    print(f"Loading SpatialLM model '{args.modified_model}' on {device}...")
    model_target = load_model(args.modified_model, device)

    sd_base = model_base.state_dict()
    sd_target = model_target.state_dict()
    only_base, only_target, diffs = compare_state_dicts(sd_base, sd_target)

    print("\nParameters only in base model:")
    for k in only_base:
        print("  -", k)
    print("\nParameters only in SpatialLM model:")
    for k in only_target:
        print("  -", k)

    print("\nShared parameters summary:")
    total_shared = len(diffs)
    shape_mismatches = [d for d in diffs if d["status"] == "shape_mismatch"]
    stats = [d for d in diffs if d["status"] == "ok"]
    print(f"  Shared total      : {total_shared}")
    print(f"  Shape mismatches  : {len(shape_mismatches)}")
    print(f"  Numeric diffs     : {len(stats)}")
    if stats:
        l2s = [d["l2"] for d in stats]
        maxabs = [d["max_abs"] for d in stats]
        print(
            f"  L2 norm   (shared): min={min(l2s):.4f}, max={max(l2s):.4f}, mean={sum(l2s) / len(l2s):.4f}"
        )
        print(
            f"  Max-abs   (shared): min={min(maxabs):.4f}, max={max(maxabs):.4f}, mean={sum(maxabs) / len(maxabs):.4f}"
        )

    if args.output == "csv":
        import csv

        with open("diffs.csv", "w", newline="") as f:
            writer = csv.DictWriter(
                f,
                fieldnames=[
                    "key",
                    "status",
                    "shape_base",
                    "shape_target",
                    "l2",
                    "max_abs",
                ],
            )
            writer.writeheader()
            for d in diffs:
                writer.writerow(d)
        print("Wrote detailed diffs to diffs.csv")


if __name__ == "__main__":
    main()
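
As a quick sanity check, compare_state_dicts can be exercised on toy state dicts. A minimal sketch (assuming the script above is saved as compare_models.py; the tensors and key names are made up for illustration):

import torch
from compare_models import compare_state_dicts

sd_a = {"shared": torch.zeros(2, 2), "base_only": torch.zeros(3)}
sd_b = {"shared": torch.ones(2, 2), "new_head": torch.zeros(3)}
only_base, only_target, diffs = compare_state_dicts(sd_a, sd_b)
# only_base == ["base_only"], only_target == ["new_head"]
# "shared" differs by a 2x2 tensor of ones, so l2 == 2.0 and max_abs == 1.0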
yunho-c commented Apr 17, 2025

Sample Input/Output:

python compare_models.py --base-model Qwen/Qwen2.5-0.5B --modified-model manycore-research/SpatialLM-Qwen-0.5B
Loading base model 'Qwen/Qwen2.5-0.5B' on cpu...
generation_config.json: 100% 138/138 [00:00<00:00, 1.82MB/s]
Loading SpatialLM model 'manycore-research/SpatialLM-Qwen-0.5B' on cpu...
model.safetensors: 100% 2.02G/2.02G [00:18<00:00, 110MB/s]
generation_config.json: 100% 121/121 [00:00<00:00, 2.73MB/s]

Parameters only in base model:

Parameters only in SpatialLM model:
  - point_backbone.extra_embedding
  - point_backbone.input_proj.bias
  - point_backbone.input_proj.weight
  - point_backbone.sparse_resnet.blocks.0.0.0.kernel
  - point_backbone.sparse_resnet.blocks.0.0.1.bias
  - point_backbone.sparse_resnet.blocks.0.0.1.weight
  - point_backbone.sparse_resnet.blocks.0.1.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.0.1.block0.1.bias
  - point_backbone.sparse_resnet.blocks.0.1.block0.1.weight
  - point_backbone.sparse_resnet.blocks.0.1.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.0.1.block1.1.bias
  - point_backbone.sparse_resnet.blocks.0.1.block1.1.weight
  - point_backbone.sparse_resnet.blocks.0.2.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.0.2.block0.1.bias
  - point_backbone.sparse_resnet.blocks.0.2.block0.1.weight
  - point_backbone.sparse_resnet.blocks.0.2.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.0.2.block1.1.bias
  - point_backbone.sparse_resnet.blocks.0.2.block1.1.weight
  - point_backbone.sparse_resnet.blocks.1.0.0.kernel
  - point_backbone.sparse_resnet.blocks.1.0.1.bias
  - point_backbone.sparse_resnet.blocks.1.0.1.weight
  - point_backbone.sparse_resnet.blocks.1.1.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.1.1.block0.1.bias
  - point_backbone.sparse_resnet.blocks.1.1.block0.1.weight
  - point_backbone.sparse_resnet.blocks.1.1.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.1.1.block1.1.bias
  - point_backbone.sparse_resnet.blocks.1.1.block1.1.weight
  - point_backbone.sparse_resnet.blocks.1.2.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.1.2.block0.1.bias
  - point_backbone.sparse_resnet.blocks.1.2.block0.1.weight
  - point_backbone.sparse_resnet.blocks.1.2.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.1.2.block1.1.bias
  - point_backbone.sparse_resnet.blocks.1.2.block1.1.weight
  - point_backbone.sparse_resnet.blocks.2.0.0.kernel
  - point_backbone.sparse_resnet.blocks.2.0.1.bias
  - point_backbone.sparse_resnet.blocks.2.0.1.weight
  - point_backbone.sparse_resnet.blocks.2.1.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.2.1.block0.1.bias
  - point_backbone.sparse_resnet.blocks.2.1.block0.1.weight
  - point_backbone.sparse_resnet.blocks.2.1.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.2.1.block1.1.bias
  - point_backbone.sparse_resnet.blocks.2.1.block1.1.weight
  - point_backbone.sparse_resnet.blocks.2.2.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.2.2.block0.1.bias
  - point_backbone.sparse_resnet.blocks.2.2.block0.1.weight
  - point_backbone.sparse_resnet.blocks.2.2.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.2.2.block1.1.bias
  - point_backbone.sparse_resnet.blocks.2.2.block1.1.weight
  - point_backbone.sparse_resnet.blocks.3.0.0.kernel
  - point_backbone.sparse_resnet.blocks.3.0.1.bias
  - point_backbone.sparse_resnet.blocks.3.0.1.weight
  - point_backbone.sparse_resnet.blocks.3.1.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.3.1.block0.1.bias
  - point_backbone.sparse_resnet.blocks.3.1.block0.1.weight
  - point_backbone.sparse_resnet.blocks.3.1.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.3.1.block1.1.bias
  - point_backbone.sparse_resnet.blocks.3.1.block1.1.weight
  - point_backbone.sparse_resnet.blocks.3.2.block0.0.kernel
  - point_backbone.sparse_resnet.blocks.3.2.block0.1.bias
  - point_backbone.sparse_resnet.blocks.3.2.block0.1.weight
  - point_backbone.sparse_resnet.blocks.3.2.block1.0.kernel
  - point_backbone.sparse_resnet.blocks.3.2.block1.1.bias
  - point_backbone.sparse_resnet.blocks.3.2.block1.1.weight
  - point_backbone.sparse_resnet.bottleneck.0.bias
  - point_backbone.sparse_resnet.bottleneck.0.weight
  - point_backbone.sparse_resnet.bottleneck.1.bias
  - point_backbone.sparse_resnet.bottleneck.1.weight
  - point_backbone.sparse_resnet.bottleneck.3.bias
  - point_backbone.sparse_resnet.bottleneck.3.weight
  - point_backbone.sparse_resnet.stem.0.0.kernel
  - point_backbone.sparse_resnet.stem.0.1.bias
  - point_backbone.sparse_resnet.stem.0.1.weight
  - point_backbone.sparse_resnet.stem.1.block0.0.kernel
  - point_backbone.sparse_resnet.stem.1.block0.1.bias
  - point_backbone.sparse_resnet.stem.1.block0.1.weight
  - point_backbone.sparse_resnet.stem.1.block1.0.kernel
  - point_backbone.sparse_resnet.stem.1.block1.1.bias
  - point_backbone.sparse_resnet.stem.1.block1.1.weight
  - point_proj.bias
  - point_proj.weight

Shared parameters summary:
  Shared total      : 291
  Shape mismatches  : 0
  Numeric diffs     : 291
  L2 norm   (shared): min=0.0097, max=48.7076, mean=3.8869
  Max-abs   (shared): min=0.0029, max=0.5024, mean=0.0422
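
Every shared tensor differs from the base (min L2 = 0.0097), which suggests the whole Qwen backbone was fine-tuned rather than frozen, while the point_backbone.* and point_proj.* keys are the added point-cloud components. With --output csv, the per-key table can be sorted to surface the most-changed tensors; a minimal sketch, assuming pandas is installed (not part of the gist):

import pandas as pd

df = pd.read_csv("diffs.csv")
# Ten shared tensors with the largest L2 change between base and SpatialLM
top = df[df["status"] == "ok"].sort_values("l2", ascending=False)
print(top[["key", "l2", "max_abs"]].head(10))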
