Skip to content

Instantly share code, notes, and snippets.

@beccajcarlson
Last active June 8, 2026 02:44
Show Gist options
  • Select an option

  • Save beccajcarlson/0104308c239fe49d91d82a1007632b27 to your computer and use it in GitHub Desktop.

Select an option

Save beccajcarlson/0104308c239fe49d91d82a1007632b27 to your computer and use it in GitHub Desktop.
Sonnet vs best DB vs >=k DB ensembles — overall accuracy on 147-gene bench

Sonnet vs best DB vs ≥k-DB ensembles — overall accuracy on 147-gene bench

Six-bar comparison of overall verdict accuracy on the 147-gene triage benchmark:

  • Sonnet (+ NCBI) — Claude Sonnet 4.6 triage agent on its canonical NCBI-resolver prompt variant
  • UniProt (TM+signal) — best single classical DB under its bench-optimized cutoff (TM OR signal-peptide positive)
  • ≥2 / ≥3 / ≥4 / ≥5 DB — ensemble callers: "yes" iff at least k of the 5 surface DBs (UniProt, GO CC, HPA, SURFY, CSPA, each under its own optimized cutoff) vote yes. k=2 is the lightest ensemble; k=5 is the strict intersection of all 5.

Accuracy uses the project's soft-credit rule: yes ≡ contextual on the positive side; no matches no only. Per-bucket breakdown (yes / contextual / no) lives in the sibling combined_db_correctness_by_class figure.

Run:

uv run make_ensemble_vs_best_db_vs_sonnet.py

Sources (fetched live from the public repo):

Canonical in-repo generator: scripts/ensemble_vs_best_db_vs_sonnet.py.

# /// script
# requires-python = ">=3.11"
# dependencies = [
# "matplotlib>=3.9",
# "pandas>=2.2",
# "seaborn>=0.13",
# "httpx>=0.27",
# ]
# ///
"""Reproduce ``ensemble_vs_best_db_vs_sonnet.{pdf,png}`` from the public repo.

Six-bar comparison on the 147-gene bench:

  * Sonnet (+ IDs)       — the Claude triage agent on its canonical
                           NCBI-resolver prompt variant
  * UniProt (TM+signal)  — best single classical DB, under its
                           bench-optimized cutoff (TM OR signal-peptide
                           positive)
  * ≥2 / ≥3 / ≥4 / ≥5 DB — ensemble callers: "yes" iff at least k of
                           the 5 surface DBs (UniProt, GO CC, HPA,
                           SURFY, CSPA, each under its own optimized
                           cutoff) vote yes

Accuracy uses the project's soft-credit rule: yes ≡ contextual on
the positive side; ``no`` matches ``no`` only.

Visual styling matches the in-repo ``_plotting_config`` (Deliverome
categorical palette + Manrope-when-available). Inlined so the gist
runs standalone.

Standalone — ``uv run make_ensemble_vs_best_db_vs_sonnet.py``.
"""
from __future__ import annotations
import io
from pathlib import Path
import httpx
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Public GitHub repository and branch that host every input table. BASE is
# the raw-content root that all table URLs below are built from.
REPO = "Deliverome-Project/accessible-surfaceome"
BRANCH = "main"
BASE = f"https://raw.githubusercontent.com/{REPO}/{BRANCH}"

# Input tables, addressed by their path below BASE.
BENCH_TSV = f"{BASE}/data/eval/triage_benchmark_v1.tsv"
CAND_TSV = f"{BASE}/data/processed/catalog/whole_proteome_catalog.tsv"
PREDS_TSV = f"{BASE}/data/processed/triage_bench/mainbench_canonical_v2.tsv"
# Per-replicate predictions — used to overlay the Sonnet bar's individual
# replicate accuracies + SEM (the DB / ensemble bars are deterministic and
# stay plain — no run-to-run variance to show).
REPS_TSV = f"{BASE}/data/processed/triage_bench/mainbench_replicates_v2.tsv"
OPT_CUTOFFS_TSV = f"{BASE}/data/processed/triage_bench/db_optimized_cutoffs.tsv"

# Published reproduction gist (embedded into output PNG Source / PDF
# Subject metadata — mirrors save_figure in _plotting_config.py).
GIST_URL = "https://gist.github.com/beccajcarlson/0104308c239fe49d91d82a1007632b27"
# ──── Inline brand styling — sentinel: brand-style-v3 ────
# Inlined copy of src/accessible_surfaceome/audit/_plotting_config.py so the
# gist carries no in-repo imports. tests/test_figure_gists_styling.py is what
# keeps the two copies in sync.

# Categorical palette.
BRAND_PALETTE = [
    "#BC3C4C",  # maroon-light
    "#3D6B60",  # teal-mid
    "#F4AA28",  # amber-bright
    "#8878C8",  # lavender-bright
    "#6E1428",  # maroon-dark
    "#7AAB9F",  # teal-light
]

# Six-step ramps per hue; each list runs from the darkest shade to the lightest.
BRAND_SEQUENTIAL = dict(
    maroon=["#3E0A18", "#6E1428", "#922038", "#BC3C4C", "#F0A098", "#FDE8E6"],
    teal=["#152E28", "#244840", "#3D6B60", "#4D8A80", "#7AAB9F", "#CCE8E4"],
    amber=["#5A2608", "#8C4210", "#C07830", "#F4AA28", "#F4C070", "#FAECD4"],
    lavender=["#1E1450", "#3A2888", "#5848A8", "#8878C8", "#A090D4", "#E4E0F8"],
)

# Single-purpose colors.
BRAND_CLAUDE_ORANGE = "#d87851"
BRAND_INK = "#1F1718"
BRAND_NEUTRAL = "#6F5D5A"
BRAND_GRID = "#E6DAD4"
def _register_brand_fonts() -> None:
candidates = [
Path(__file__).resolve().parents[3] / "assets" / "fonts",
Path.cwd() / "assets" / "fonts",
]
for fonts_dir in candidates:
if fonts_dir.is_dir():
for path in sorted(list(fonts_dir.glob("*.ttf")) + list(fonts_dir.glob("*.otf"))):
try:
fm.fontManager.addfont(str(path))
except Exception: # noqa: BLE001
continue
return
def _apply_brand_style() -> None:
    """Apply the brand look to matplotlib / seaborn. Sentinel: brand-style-v3.

    Inline stand-in for the in-repo `setup_plotting_style`. Registers the
    brand fonts, selects seaborn's "whitegrid" style and "notebook" context,
    then overrides matplotlib rcParams with the values below.

    v2 note: sizes were bumped ~25% and the font weight is set to an explicit
    "medium" (avoids the ExtraLight default that matplotlib picks from the
    Manrope variable file). Companion to the static
    Manrope-{regular,medium,semibold,bold}.otf files in assets/fonts/.
    """
    _register_brand_fonts()
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.0)

    rc_overrides = {
        # Output: high-resolution, tightly cropped, transparent background.
        "savefig.dpi": 300,
        "savefig.bbox": "tight",
        "figure.facecolor": "none",
        "savefig.facecolor": "none",
        # Typography.
        "font.family": "sans-serif",
        "font.sans-serif": ["Manrope", "Outfit", "DejaVu Sans", "Liberation Sans", "Arial"],
        "font.weight": "medium",
        "font.size": 21,
        # Axes.
        "axes.labelsize": 25,
        "axes.labelweight": "medium",
        "axes.titlesize": 0,
        "axes.titlepad": 0,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.grid": True,
        "axes.axisbelow": True,
        "axes.edgecolor": BRAND_GRID,
        "axes.labelcolor": BRAND_INK,
        "axes.facecolor": "none",
        "text.color": BRAND_INK,
        # Grid.
        "grid.alpha": 0.35,
        "grid.linestyle": "-",
        "grid.linewidth": 0.7,
        "grid.color": BRAND_GRID,
        # Ticks.
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
        "xtick.color": BRAND_INK,
        "ytick.color": BRAND_INK,
        # Legend.
        "legend.frameon": False,
        "legend.fontsize": 20,
        # Patches (bars) are drawn without an outline.
        "patch.edgecolor": "none",
        "patch.linewidth": 0.0,
    }
    plt.rcParams.update(rc_overrides)
# Display names of the five surface databases that vote in the ensembles.
DB_LABELS = [
    "UniProt",
    "GO CC",
    "HPA",
    "SURFY",
    "CSPA",
]

# Bar color per ensemble threshold k (the ≥k ensembles). Sequential teal
# ramp: light → dark = permissive → strict.
ENSEMBLE_PALETTE = {
    2: "#7AAB9F",  # teal-light
    3: "#4D8A80",
    4: "#3D6B60",
    5: "#244840",  # teal-deep
}
def _fetch_tsv(url: str) -> pd.DataFrame:
local = Path(__file__).resolve().parents[3] / url[len(BASE) + 1:]
if local.is_file():
return pd.read_csv(local, sep="\t")
r = httpx.get(url, timeout=30)
r.raise_for_status()
return pd.read_csv(io.StringIO(r.text), sep="\t")
def _vote_correct(vote: str, truth: str) -> bool:
if vote == truth:
return True
return vote in ("yes", "contextual") and truth in ("yes", "contextual")
def main() -> None:
    """Score every caller on the bench, print the table, and save the figure.

    Steps:
      1. Apply the brand style and load the benchmark, catalog, canonical
         prediction, and optimized-cutoff tables via ``_fetch_tsv``.
      2. Build six callers — Sonnet, UniProt alone, and the ≥k-of-5 DB
         ensembles for k = 2..5 — each mapping a gene symbol to a verdict.
      3. Score each caller with ``_vote_correct``. When per-replicate
         predictions can be loaded, the Sonnet bar uses the mean of the
         per-replicate accuracies instead of the canonical predictions.
      4. Draw the bar chart (replicate points + SEM on the Sonnet bar when at
         least two replicates exist) and write
         ``ensemble_vs_best_db_vs_sonnet.pdf`` / ``.png`` to the current
         working directory.
    """
    import sys  # local: only needed for the replicate-fallback warning
    from collections.abc import Callable  # local: annotation of ``callers``

    _apply_brand_style()
    bench = _fetch_tsv(BENCH_TSV)
    cand = _fetch_tsv(CAND_TSV).set_index("uniprot_acc")
    preds = _fetch_tsv(PREDS_TSV)
    opt = _fetch_tsv(OPT_CUTOFFS_TSV)
    uniprot_opt = set(opt.loc[opt["uniprot_optimized"] == 1, "accession"].astype(str))
    cspa_opt = set(opt.loc[opt["cspa_optimized"] == 1, "accession"].astype(str))
    truth_by_gene = dict(zip(bench["gene_symbol"], bench["ground_truth_verdict"], strict=True))
    acc_by_gene = dict(zip(bench["gene_symbol"], bench["uniprot_acc"], strict=True))

    def db_votes_for(acc: str) -> dict[str, bool]:
        """Return each DB's yes/no vote for one UniProt accession."""
        out = {label: False for label in DB_LABELS}
        if not acc:
            return out
        out["UniProt"] = acc in uniprot_opt
        out["CSPA"] = acc in cspa_opt
        if acc in cand.index:
            row = cand.loc[acc]
            out["GO CC"] = row["go_surface_flag"] == 1
            out["HPA"] = row["hpa_surface_flag"] == 1
            out["SURFY"] = row["surfy_surface_flag"] == 1
        return out

    genes = list(truth_by_gene)
    # The DB votes depend only on the gene, not on the caller, so compute
    # them once instead of once per caller.
    votes_by_gene = {g: db_votes_for(acc_by_gene.get(g, "")) for g in genes}

    sonnet_ncbi = preds[
        (preds["model"] == "claude-sonnet-4-6") & (preds["prompt_variant"] == "ncbi")
    ].set_index("gene_symbol")["predicted_verdict"].to_dict()

    # (bar label, gene symbol -> verdict, bar color). Sonnet must stay first:
    # the replicate overlay below is drawn on bars[0].
    callers: list[tuple[str, Callable[[str], str], str]] = [
        ("Sonnet (+ IDs)", lambda g: sonnet_ncbi.get(g) or "no", BRAND_CLAUDE_ORANGE),
        ("UniProt\n(TM+signal)", lambda g: "yes" if votes_by_gene[g]["UniProt"] else "no",
         BRAND_PALETTE[0]),
    ]
    for k in (2, 3, 4, 5):
        callers.append((
            f"≥{k} DB",
            # k=k binds the current threshold; a plain closure would see k=5.
            (lambda g, k=k: "yes" if sum(votes_by_gene[g].values()) >= k else "no"),
            ENSEMBLE_PALETTE[k],
        ))

    # Per-replicate Sonnet accuracies (for the bar height = mean-of-reps,
    # the points, and the SEM). The DB / ensemble callers are deterministic
    # — single value, no replicates.
    sonnet_rep_accs: list[float] = []
    try:
        _reps = _fetch_tsv(REPS_TSV)
        _s = _reps[(_reps["model"] == "claude-sonnet-4-6")
                   & (_reps["prompt_variant"] == "ncbi")].copy()
        _s["is_match"] = _s["is_match"].astype(int)
        sonnet_rep_accs = _s.groupby("replicate")["is_match"].mean().tolist()
    except Exception as exc:  # noqa: BLE001
        # Best effort by design: the figure is still produced without the
        # replicate table. Say so, because the Sonnet bar then shows the
        # canonical-prediction accuracy and carries no points / SEM.
        print(
            f"warning: could not load replicate predictions ({exc!r}); "
            "Sonnet bar uses canonical predictions without replicate overlay",
            file=sys.stderr,
        )
        sonnet_rep_accs = []

    rows = []
    for label, vote_fn, color in callers:
        n_correct = sum(_vote_correct(vote_fn(g), truth_by_gene[g]) for g in genes)
        acc = n_correct / len(genes)
        # Sonnet bar = MEAN of per-replicate accuracies (so bar, points, and
        # SEM share one center). DB / ensemble bars keep their deterministic
        # majority count. The inner count label for Sonnet becomes the rounded
        # mean count (≈ avg correct per single run), the others stay exact.
        if label.startswith("Sonnet") and sonnet_rep_accs:
            acc = sum(sonnet_rep_accs) / len(sonnet_rep_accs)
            n_correct = round(acc * len(genes))
        rows.append({
            "caller": label, "accuracy": acc,
            "n_correct": n_correct, "n_total": len(genes), "color": color,
        })
    df = pd.DataFrame(rows)

    print("\nOverall accuracy on 147-gene bench:")
    for _, row in df.iterrows():
        # chr(10) is "\n": flatten two-line bar labels for the console table.
        print(f" {row['caller'].replace(chr(10), ' '):20s} "
              f"{row['accuracy']*100:5.1f}% ({row['n_correct']}/{row['n_total']})")

    fig, ax = plt.subplots(figsize=(10, 5.5))
    bars = ax.bar(
        df["caller"], df["accuracy"] * 100,
        color=df["color"], edgecolor="none", width=0.6,
    )
    for bar, row in zip(bars, df.to_dict("records"), strict=True):
        # Percentage above the bar.
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 1.2,
            f"{row['accuracy']*100:.1f}%",
            ha="center", va="bottom",
            fontsize=21, fontweight="bold", color=BRAND_INK,
        )
        # correct/total count centered inside the bar.
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() / 2,
            f"{row['n_correct']}/{row['n_total']}",
            ha="center", va="center",
            fontsize=15, color="white", fontweight="bold",
        )

    # Overlay individual-replicate accuracies + SEM on the Sonnet bar only.
    # The Sonnet bar height is the MEAN of these per-rep accuracies (set
    # above), so the error bar (centered on that same mean) and the points
    # all line up with the bar top. DB / ensemble bars are deterministic and
    # stay plain.
    if len(sonnet_rep_accs) >= 2:
        rep_accs = [v * 100 for v in sonnet_rep_accs]
        sonnet_bar = bars[0]
        xc = sonnet_bar.get_x() + sonnet_bar.get_width() / 2
        m = sum(rep_accs) / len(rep_accs)
        # Sample standard deviation (n - 1 denominator), then SEM = sd / sqrt(n).
        sd = (sum((v - m) ** 2 for v in rep_accs) / (len(rep_accs) - 1)) ** 0.5
        sem = sd / (len(rep_accs) ** 0.5)
        ax.errorbar(xc, m, yerr=sem, fmt="none", ecolor=BRAND_INK,
                    elinewidth=1.2, capsize=4, capthick=1.2, zorder=4)
        for j, rv in enumerate(rep_accs):
            # Spread the points symmetrically around the bar center.
            jitter = (j - (len(rep_accs) - 1) / 2) * (sonnet_bar.get_width() * 0.16)
            ax.scatter(xc + jitter, rv, s=24, color=BRAND_INK,
                       edgecolor="white", linewidth=0.5, zorder=5, alpha=0.9)

    ax.set_ylabel("Overall accuracy on\n147-gene bench (%)", fontsize=19)
    ax.set_ylim(0, 105)
    ax.tick_params(axis="x", labelsize=17)
    ax.tick_params(axis="y", labelsize=14)
    sns.despine(ax=ax, top=True, right=True)
    fig.tight_layout()

    out_pdf = Path("ensemble_vs_best_db_vs_sonnet.pdf")
    out_png = Path("ensemble_vs_best_db_vs_sonnet.png")
    # The gist URL is embedded in the file metadata (PDF Subject / PNG Source).
    fig.savefig(out_pdf, bbox_inches="tight", metadata={"Subject": GIST_URL})
    fig.savefig(out_png, bbox_inches="tight", dpi=300, metadata={"Source": GIST_URL})
    print(f"Wrote {out_pdf} + {out_png}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment