Created
April 17, 2025 00:34
-
-
Save yunho-c/2938846cfb314d53f9ca4a6c0a7666bb to your computer and use it in GitHub Desktop.
Compare (diff) two HuggingFace Transformer Models (specifically: SpatialLM vs. base)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Base Models (w/ sizes):
    meta-llama/Llama-3.2-1B
    Qwen/Qwen2.5-0.5B
SpatialLM Models:
    manycore-research/SpatialLM-Llama-1B
    manycore-research/SpatialLM-Qwen-0.5B
Commands:
    python compare_models.py --base-model Qwen/Qwen2.5-0.5B --modified-model manycore-research/SpatialLM-Qwen-0.5B
    python compare_models.py --base-model meta-llama/Llama-3.2-1B --modified-model manycore-research/SpatialLM-Llama-1B
""" | |
""" | |
compare_models.py
Usage:
    python compare_models.py \
        --base-model <huggingface_id> \
        --modified-model <huggingface_id> \
        [--output csv] \
        [--device cpu]
This script:
    1) Loads both models (via HuggingFace Transformers).
    2) Extracts their state_dicts.
    3) Reports:
        - Keys only in base or only in SpatialLM.
        - For shared keys: shape mismatches.
        - For matching shapes: per-key L2 & max-abs difference.
    4) Optionally writes a CSV summary (diffs.csv) when --output csv is given.
""" | |
import argparse | |
import torch | |
import numpy as np | |
from transformers import AutoModel, AutoConfig, AutoModelForCausalLM | |
from spatiallm import SpatialLMLlamaForCausalLM, SpatialLMQwenForCausalLM | |
def load_model(path_or_name, device):
    """Load a causal-LM checkpoint and place it on ``device``.

    Args:
        path_or_name: Identifier forwarded to
            ``AutoModelForCausalLM.from_pretrained`` (the CLI describes it as
            a HuggingFace name or a local directory).
        device: Target forwarded to the model's ``.to()`` call.

    Returns:
        The object returned by ``.to(device)`` on the loaded model.
    """
    # AutoModelForCausalLM is used rather than plain AutoModel (with or
    # without an explicit AutoConfig), which earlier revisions experimented
    # with. trust_remote_code=True is passed through unchanged.
    loaded = AutoModelForCausalLM.from_pretrained(
        path_or_name,
        trust_remote_code=True,
    )
    loaded = loaded.to(device)
    return loaded
def compare_state_dicts(sd_base, sd_target):
    """Diff two state dicts key by key.

    Args:
        sd_base: Mapping of parameter name -> tensor for the base model.
        sd_target: Mapping of parameter name -> tensor for the model being
            compared against the base.

    Returns:
        A tuple ``(only_base, only_target, diffs)`` where

        * ``only_base`` is the sorted list of keys present only in ``sd_base``;
        * ``only_target`` is the sorted list of keys present only in
          ``sd_target``;
        * ``diffs`` is a list with one dict per shared key (sorted by key)
          holding ``key``, ``status`` (``"ok"`` or ``"shape_mismatch"``),
          ``shape_base``, ``shape_target``, ``l2`` and ``max_abs``.  The two
          metrics are Python floats for ``"ok"`` entries and ``None`` for
          shape mismatches.
    """
    keys_base = set(sd_base)
    keys_target = set(sd_target)
    only_base = sorted(keys_base - keys_target)
    only_target = sorted(keys_target - keys_base)

    diffs = []
    for key in sorted(keys_base & keys_target):
        t_base = sd_base[key]
        t_target = sd_target[key]

        # Start from the "cannot compare" record; fill in the metrics only
        # when the shapes agree.
        record = {
            "key": key,
            "status": "shape_mismatch",
            "shape_base": t_base.shape,
            "shape_target": t_target.shape,
            "l2": None,
            "max_abs": None,
        }

        if t_base.shape == t_target.shape:
            # Compare in float32 regardless of the stored dtype; detach so
            # tensors that still track gradients can be converted to NumPy.
            delta = (
                (t_target.detach().float() - t_base.detach().float())
                .cpu()
                .numpy()
            )
            record["status"] = "ok"
            # Plain Python floats keep both metrics the same type.
            record["l2"] = float(np.linalg.norm(delta.ravel()))
            # np.max raises on zero-element arrays; an empty tensor has no
            # difference, so report 0.0 instead of crashing.
            record["max_abs"] = float(np.abs(delta).max()) if delta.size else 0.0

        diffs.append(record)

    return only_base, only_target, diffs
def main():
    """Command-line entry point: load two models and report how they differ.

    Parses ``--base-model``, ``--modified-model``, ``--device`` and
    ``--output``; loads both models with ``load_model``; compares their
    state dicts with ``compare_state_dicts``; prints the keys unique to each
    model plus aggregate L2 / max-abs statistics for the shared keys; and,
    when ``--output csv`` is given, writes every per-key record to
    ``diffs.csv`` in the current working directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-model",
        required=True,
        help="HuggingFace name or local dir of the base model",
    )
    parser.add_argument(
        "--modified-model",
        required=True,
        help="HuggingFace name or local dir of the SpatialLM variant",
    )
    parser.add_argument(
        "--device", default="cpu", help="torch device to load models on (cpu or cuda)"
    )
    parser.add_argument(
        "--output",
        choices=["csv", "none"],
        default="none",
        help="if csv, dump detailed diff to diffs.csv",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Loading base model '{args.base_model}' on {device}...")
    model_base = load_model(args.base_model, device)
    print(f"Loading SpatialLM model '{args.modified_model}' on {device}...")
    model_target = load_model(args.modified_model, device)

    only_base, only_target, diffs = compare_state_dicts(
        model_base.state_dict(), model_target.state_dict()
    )

    print("\nParameters only in base model:")
    for k in only_base:
        print(" -", k)
    print("\nParameters only in SpatialLM model:")
    for k in only_target:
        print(" -", k)

    print("\nShared parameters summary:")
    total_shared = len(diffs)
    shape_mismatches = [d for d in diffs if d["status"] == "shape_mismatch"]
    stats = [d for d in diffs if d["status"] == "ok"]
    print(f" Shared total : {total_shared}")
    print(f" Shape mismatches : {len(shape_mismatches)}")
    print(f" Numeric diffs : {len(stats)}")

    # min()/max() would raise on an empty list, so only print the aggregate
    # lines when at least one shared key had matching shapes.
    if stats:
        l2s = [d["l2"] for d in stats]
        maxabs = [d["max_abs"] for d in stats]
        print(
            f" L2 norm (shared): min={min(l2s):.4f}, max={max(l2s):.4f}, mean={sum(l2s) / len(l2s):.4f}"
        )
        print(
            f" Max-abs (shared): min={min(maxabs):.4f}, max={max(maxabs):.4f}, mean={sum(maxabs) / len(maxabs):.4f}"
        )

    if args.output == "csv":
        import csv

        fieldnames = [
            "key",
            "status",
            "shape_base",
            "shape_target",
            "l2",
            "max_abs",
        ]
        # Explicit UTF-8 so the file does not depend on the platform's
        # default encoding; newline="" is what the csv module expects.
        with open("diffs.csv", "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(diffs)
        print("Wrote detailed diffs to diffs.csv")
# Run the comparison only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or model loading.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Sample Input/Output: