Skip to content

Instantly share code, notes, and snippets.

@beccajcarlson
Last active June 8, 2026 03:08
Show Gist options
  • Select an option

  • Save beccajcarlson/2bb4f7aac629535982c142bc2032e04d to your computer and use it in GitHub Desktop.

Select an option

Save beccajcarlson/2bb4f7aac629535982c142bc2032e04d to your computer and use it in GitHub Desktop.
Surface callers: bench correctness by verdict bucket (accessible-surfaceome)

DB correctness by class — optimized cutoffs vs Sonnet+NCBI on the 147-gene bench

5 surface DBs (UniProt, GO CC, HPA, SURFY, CSPA) + Sonnet+NCBI, grouped bars showing accuracy per ground-truth class (overall / yes / contextual / no). DB cutoffs are the trade-off-audit optimized versions, not canonical baselines:

  • UniProt — TM+signal: admit any accession with a TM domain, a signal peptide, OR a strict surface subcellular term (looser than canonical; rescues more bench positives without hurting the no-class).
  • CSPA — HC-only: admit only the high-confidence flag (drops putative + unspecific; stricter than canonical, lifts precision against the no-class).
  • GO CC / HPA / SURFY: canonical baselines (audit didn't surface a better cutoff).

See the companion db_cutoff_tradeoff figure for the audit that recommends these cutoffs.

Run:

uv run make_db_correctness_by_class.py

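Sources (read from a local repo checkout when one is present, otherwise fetched live from raw.githubusercontent.com on the repo's main branch):

  • data/eval/triage_benchmark_v1.tsv — the 147-gene bench with ground-truth verdicts
  • data/processed/catalog/whole_proteome_catalog.tsv — GO CC / HPA / SURFY surface flags
  • data/processed/triage_bench/mainbench_canonical_v2.tsv — canonical Sonnet+NCBI predictions
  • data/processed/triage_bench/mainbench_replicates_v2.tsv — per-replicate Sonnet predictions (overlay points + SEM)
  • data/processed/triage_bench/db_optimized_cutoffs.tsv — accessions admitted by the optimized UniProt / CSPA cutoffs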

Canonical in-repo generator: scripts/triage_bench_db_barplot.py::make_by_class_plot with _USE_OPTIMIZED_CUTOFFS = True. The optimized-cutoff accession TSV (data/processed/triage_bench/db_optimized_cutoffs.tsv) is dumped as a side effect of the same function via _dump_db_optimized_cutoffs.

# /// script
# requires-python = ">=3.11"
# dependencies = [
# "matplotlib>=3.9",
# "pandas>=2.2",
# "seaborn>=0.13",
# "httpx>=0.27",
# ]
# ///
"""Reproduce ``db_correctness_by_class.{pdf,png}`` from the public repo.
For 6 callers (5 DBs + Sonnet+NCBI), grouped bars showing
overall accuracy + per-verdict-bucket accuracy (yes / contextual /
no) on the 147-gene bench.
DB cutoffs are the **trade-off-audit optimized** versions (see
``scripts/triage_bench_db_barplot.py::_optimized_uniprot_accs`` /
``_optimized_cspa_accs`` and the ``db_cutoff_tradeoff`` figure):
* **UniProt — TM+signal**: admit any accession with a transmembrane
domain, a signal peptide, OR a strict surface subcellular term
(looser than canonical; rescues more bench positives without
hurting the no-class).
* **CSPA — HC-only**: admit only the high-confidence flag (drops
``putative`` + ``unspecific`` rows; stricter than canonical, lifts
precision against the no-class).
* **GO CC / HPA / SURFY**: canonical baselines (audit didn't surface
a better cutoff).
Visual styling matches the in-repo `_plotting_config` (Deliverome
categorical palette + Manrope-when-available + whitegrid + despine +
transparent facecolor at 300 DPI). The styling block is inlined so the
gist runs standalone without depending on the project's plotting module.
Standalone — ``uv run make_db_correctness_by_class.py``.
"""
from __future__ import annotations

import io
import statistics
import sys
from pathlib import Path

import httpx
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Public repo + branch that every data TSV below lives in. _fetch_tsv
# reads the same relative path from a local checkout first, when present.
REPO = "Deliverome-Project/accessible-surfaceome"
BRANCH = "main"
BASE = f"https://raw.githubusercontent.com/{REPO}/{BRANCH}"
# Published reproduction gist (embedded into the output PNG's Source
# tEXt chunk + PDF's Subject info field — same pattern as the canonical
# save_figure helper in src/accessible_surfaceome/audit/_plotting_config.py
# so the figure carries its source URL even when dragged into a Substack
# draft or copied to Slack). Read back with `exiftool figure.png | grep Source`,
# or in Python: `from PIL import Image; Image.open(p).info["Source"]`.
GIST_URL = "https://gist.github.com/beccajcarlson/2bb4f7aac629535982c142bc2032e04d"
# Triage bench: main() reads its gene_symbol, ground_truth_verdict and
# uniprot_acc columns.
BENCH_TSV = f"{BASE}/data/eval/triage_benchmark_v1.tsv"
# Whole-proteome catalog, indexed by uniprot_acc in main(); supplies the
# canonical go_/hpa_/surfy_surface_flag columns for GO CC / HPA / SURFY.
CAND_TSV = f"{BASE}/data/processed/catalog/whole_proteome_catalog.tsv"
# Canonical LLM predictions; main() keeps the claude-sonnet-4-6 rows with
# prompt_variant == "ncbi" and reads predicted_verdict per gene_symbol.
PREDS_TSV = f"{BASE}/data/processed/triage_bench/mainbench_canonical_v2.tsv"
# Per-replicate predictions — overlays the Sonnet bars' individual-rep
# accuracy + SEM, per verdict bucket. DB bars are deterministic, no overlay.
REPS_TSV = f"{BASE}/data/processed/triage_bench/mainbench_replicates_v2.tsv"
# Optimized DB-cutoff accession TSV (one row per accession admitted
# by EITHER optimized rule; columns mark which). Dumped by the
# canonical generator's _dump_db_optimized_cutoffs when the by_class
# plot regenerates.
OPT_CUTOFFS = f"{BASE}/data/processed/triage_bench/db_optimized_cutoffs.tsv"
# ──── Inline brand styling — sentinel: brand-style-v3 ────
# Mirrors src/accessible_surfaceome/audit/_plotting_config.py so the gist
# stays self-contained (no in-repo imports — Substack readers run it
# standalone). Kept in sync via tests/test_figure_gists_styling.py.
# Categorical palette; DB_PALETTE assigns the first five entries to the DBs.
BRAND_PALETTE = [
    "#BC3C4C",  # maroon-light
    "#3D6B60",  # teal-mid
    "#F4AA28",  # amber-bright
    "#8878C8",  # lavender-bright
    "#6E1428",  # maroon-dark
    "#7AAB9F",  # teal-light
]
# Per-hue sequential ramps, darkest -> lightest. Not referenced elsewhere
# in this script; kept so the inline block mirrors the in-repo config.
BRAND_SEQUENTIAL = {
    "maroon": ["#3E0A18", "#6E1428", "#922038", "#BC3C4C", "#F0A098", "#FDE8E6"],
    "teal": ["#152E28", "#244840", "#3D6B60", "#4D8A80", "#7AAB9F", "#CCE8E4"],
    "amber": ["#5A2608", "#8C4210", "#C07830", "#F4AA28", "#F4C070", "#FAECD4"],
    "lavender": ["#1E1450", "#3A2888", "#5848A8", "#8878C8", "#A090D4", "#E4E0F8"],
}
# Colour of the LLM (Sonnet) bars.
BRAND_CLAUDE_ORANGE = "#d87851"
# Primary text colour: axis labels, ticks, bar annotations, overlay points.
BRAND_INK = "#1F1718"
# Secondary text colour (the n-per-class subtitle).
BRAND_NEUTRAL = "#6F5D5A"
# Axes-edge and gridline colour.
BRAND_GRID = "#E6DAD4"
def _register_brand_fonts() -> None:
    """Register repo-local brand fonts (e.g. Manrope, Playfair) with matplotlib.

    Candidate ``assets/fonts/`` directories are tried in order: one under a
    repo checkout (three directory levels above this script's folder, i.e.
    ``parents[3]``), then one under the current working directory. Every
    ``*.ttf`` / ``*.otf`` in the FIRST directory that exists is added to
    matplotlib's font manager; later candidates are ignored. Individual
    fonts that fail to load are skipped. External readers without the repo
    fall back to the next entry in ``font.sans-serif`` — typically DejaVu
    Sans — without erroring.
    """
    candidates: list[Path] = []
    script_parents = Path(__file__).resolve().parents
    # A standalone gist copy can sit in a shallow directory (``/tmp/x.py``
    # has only two ancestors); indexing parents[3] there would raise
    # IndexError and crash the very use case this fallback exists for.
    if len(script_parents) > 3:
        candidates.append(script_parents[3] / "assets" / "fonts")  # repo checkout
    candidates.append(Path.cwd() / "assets" / "fonts")  # cwd run
    for fonts_dir in candidates:
        if not fonts_dir.is_dir():
            continue
        font_files = sorted([*fonts_dir.glob("*.ttf"), *fonts_dir.glob("*.otf")])
        for path in font_files:
            try:
                fm.fontManager.addfont(str(path))
            except Exception:  # noqa: BLE001 — best-effort
                continue
        return
def _apply_brand_style() -> None:
    """Apply the Deliverome brand look to matplotlib/seaborn globally.

    Self-contained stand-in for the in-repo ``setup_plotting_style`` so the
    gist needs no project imports. Sentinel: brand-style-v3.

    Registers any repo-local brand fonts, selects seaborn's ``whitegrid``
    style at ``notebook`` scale, then overrides rcParams. It sets 300-DPI
    tight saves on a transparent background and a Manrope-first sans-serif
    stack at an explicit medium weight, so matplotlib does not pick the
    ExtraLight face from the Manrope variable file. Label, tick and legend
    sizes were enlarged (~25% in v2). Titles get zero size and padding, the
    top/right spines are hidden, and brand ink/grid colours are used
    throughout.
    """
    _register_brand_fonts()
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.0)

    # Output: high-DPI tight saves on a fully transparent background.
    output_rc = {
        "savefig.dpi": 300,
        "savefig.bbox": "tight",
        "figure.facecolor": "none",
        "savefig.facecolor": "none",
    }
    # Typography: Manrope first, explicit medium weight.
    type_rc = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Manrope", "Outfit", "DejaVu Sans", "Liberation Sans", "Arial"],
        "font.weight": "medium",
        "font.size": 21,
    }
    # Axes: large labels, zero-size titles, open top/right, grid under data.
    axes_rc = {
        "axes.labelsize": 25,
        "axes.labelweight": "medium",
        "axes.titlesize": 0,
        "axes.titlepad": 0,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.grid": True,
        "axes.axisbelow": True,
        "axes.edgecolor": BRAND_GRID,
        "axes.labelcolor": BRAND_INK,
        "axes.facecolor": "none",
        "text.color": BRAND_INK,
    }
    # Grid: faint solid brand-grid lines.
    grid_rc = {
        "grid.alpha": 0.35,
        "grid.linestyle": "-",
        "grid.linewidth": 0.7,
        "grid.color": BRAND_GRID,
    }
    # Ticks: large labels in brand ink.
    tick_rc = {
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
        "xtick.color": BRAND_INK,
        "ytick.color": BRAND_INK,
    }
    # Legend + patches: frameless legend, borderless bars.
    misc_rc = {
        "legend.frameon": False,
        "legend.fontsize": 20,
        "patch.edgecolor": "none",
        "patch.linewidth": 0.0,
    }
    plt.rcParams.update({**output_rc, **type_rc, **axes_rc, **grid_rc, **tick_rc, **misc_rc})
# Surface databases compared against the LLM caller, in canonical order;
# main() re-sorts them by overall accuracy for display.
DB_LABELS = ["UniProt", "GO CC", "HPA", "SURFY", "CSPA"]
# Bar colour per caller. The five DBs take the first five brand-palette
# entries in DB_LABELS order (maroon-light, teal-mid, amber-bright,
# lavender-bright, maroon-dark); the LLM caller gets Claude orange.
DB_PALETTE = dict(zip(DB_LABELS, BRAND_PALETTE))
DB_PALETTE["Sonnet (+ IDs)"] = BRAND_CLAUDE_ORANGE
# Plot column groups, left to right: overall accuracy, then one group per
# ground-truth verdict bucket.
COLUMNS = ["overall", "yes", "contextual", "no"]
# x-tick label per column group. A "contextual" gene is scored correct on
# either a yes or a contextual vote (see _vote_correct).
COLUMN_LABEL = dict(zip(COLUMNS, [
    "overall\n(all 147 proteins)",
    "yes",
    "contextual\n(yes-vote = correct)",
    "no",
]))
def _fetch_tsv(url: str) -> pd.DataFrame:
    """Load a tab-separated file, preferring a local repo checkout.

    When *url* lives under ``BASE`` and a file exists at the same relative
    path under ``parents[3]`` of this script (the repo root when the script
    sits three directory levels below it), that file is read from disk so
    contributors can regenerate without network access. Otherwise the raw
    URL is fetched over HTTP(S) with a 30 s timeout.

    Note: the raw URL only works once the repo is public AND the file is
    LFS-exempted in .gitattributes — both were pending for the eval /
    processed TSVs as of the 5-figure publish.

    Raises:
        httpx.HTTPStatusError: the raw URL returned a non-success status.
    """
    prefix = BASE + "/"
    if url.startswith(prefix):
        script_parents = Path(__file__).resolve().parents
        # A standalone gist copy can sit in a shallow directory (``/tmp/x.py``
        # has only two ancestors); parents[3] would raise IndexError there
        # instead of falling through to the network fetch.
        if len(script_parents) > 3:
            local = script_parents[3] / url[len(prefix):]
            if local.is_file():
                return pd.read_csv(local, sep="\t")
    response = httpx.get(url, timeout=30)
    response.raise_for_status()
    return pd.read_csv(io.StringIO(response.text), sep="\t")
def _vote_correct(vote: str, truth: str) -> bool:
if vote == truth:
return True
return vote in ("yes", "contextual") and truth in ("yes", "contextual")
def _bucket_genes(truth_by_gene: dict[str, str], bucket: str) -> list[str]:
    """Return the bench genes scored in *bucket*.

    ``"overall"`` selects every bench gene; any other bucket name selects
    the genes whose ground-truth verdict equals it.
    """
    if bucket == "overall":
        return list(truth_by_gene)
    return [gene for gene, truth in truth_by_gene.items() if truth == bucket]


def _sonnet_replicate_accuracies(truth_by_gene: dict[str, str]) -> dict[str, list[float]]:
    """Per-bucket Sonnet+NCBI accuracy of each replicate, read from REPS_TSV.

    Returns ``{bucket: [acc_rep_a, acc_rep_b, ...]}`` (replicates in sorted
    id order) for every bucket in COLUMNS that has at least one replicate
    row. Best-effort: if the TSV cannot be fetched or parsed, a warning is
    written to stderr and ``{}`` is returned. The figure then still renders,
    with Sonnet bars falling back to the canonical-prediction fraction and
    no overlay.

    NOTE(review): uses the TSV's own ``is_match`` column — confirm it scores
    a ``yes`` vote on a ``contextual`` gene as correct, like _vote_correct.
    """
    try:
        reps = _fetch_tsv(REPS_TSV)
        srep = reps[(reps["model"] == "claude-sonnet-4-6")
                    & (reps["prompt_variant"] == "ncbi")].copy()
        srep["is_match"] = srep["is_match"].astype(int)
        srep["truth"] = srep["gene_symbol"].map(truth_by_gene)
        rep_ids = sorted(srep["replicate"].unique())
        accs: dict[str, list[float]] = {}
        for bucket in COLUMNS:
            sub = srep if bucket == "overall" else srep[srep["truth"] == bucket]
            per_rep: list[float] = []
            for rid in rep_ids:
                cell = sub[sub["replicate"] == rid]
                if len(cell):
                    per_rep.append(float(cell["is_match"].mean()))
            if per_rep:
                accs[bucket] = per_rep
    except Exception as err:  # noqa: BLE001 — best-effort overlay, but say so
        print(f"warning: Sonnet replicate overlay skipped ({err!r})", file=sys.stderr)
        return {}
    return accs


def main() -> None:
    """Score every caller per verdict bucket and write the figure.

    Fetches the bench, catalog, canonical predictions, optimized DB cutoffs
    and (best-effort) Sonnet replicates. It then draws grouped bars: the
    Sonnet bar first, followed by the DBs sorted by overall accuracy. The
    result is written to ``db_correctness_by_class.{pdf,png}`` in the
    working directory, with GIST_URL embedded in the file metadata.

    Raises:
        ValueError: the bench TSV has no rows.
    """
    bench = _fetch_tsv(BENCH_TSV)
    cand = _fetch_tsv(CAND_TSV).set_index("uniprot_acc")
    preds = _fetch_tsv(PREDS_TSV)
    opt = _fetch_tsv(OPT_CUTOFFS)
    uniprot_opt = set(opt.loc[opt["uniprot_optimized"] == 1, "accession"].astype(str))
    cspa_opt = set(opt.loc[opt["cspa_optimized"] == 1, "accession"].astype(str))
    truth_by_gene = dict(zip(bench["gene_symbol"], bench["ground_truth_verdict"], strict=True))
    acc_by_gene = dict(zip(bench["gene_symbol"], bench["uniprot_acc"], strict=True))
    if not truth_by_gene:
        raise ValueError(f"benchmark TSV has no rows: {BENCH_TSV}")

    # Accessions each DB calls "surface". UniProt / CSPA use the optimized
    # cutoffs; GO CC / HPA / SURFY use the catalog's canonical flag columns.
    # Precomputed sets give O(1) lookups and tolerate duplicate catalog
    # rows (a per-row `cand.loc[acc, col] == 1` would yield a Series there
    # and raise when truth-tested).
    flag_cols = {
        "GO CC": "go_surface_flag",
        "HPA": "hpa_surface_flag",
        "SURFY": "surfy_surface_flag",
    }
    surface_accs: dict[str, set] = {"UniProt": uniprot_opt, "CSPA": cspa_opt}
    for label, col in flag_cols.items():
        surface_accs[label] = set(cand.index[(cand[col] == 1).to_numpy()])

    def _db_vote(gene: str, source: str) -> str:
        acc = acc_by_gene.get(gene)
        # read_csv gives NaN (a truthy float) for blank cells, so `not acc`
        # alone would let a missing accession through.
        if not isinstance(acc, str) or not acc:
            return "no"
        return "yes" if acc in surface_accs[source] else "no"

    sonnet_label = "Sonnet (+ IDs)"
    sonnet_ncbi = preds[
        (preds["model"] == "claude-sonnet-4-6") & (preds["prompt_variant"] == "ncbi")
    ].set_index("gene_symbol")["predicted_verdict"].to_dict()

    def _sonnet_vote(gene: str) -> str:
        vote = sonnet_ncbi.get(gene)
        # Missing gene or NaN verdict -> "no"; `vote or "no"` would pass NaN
        # through as a vote that can never be scored correct.
        return vote if isinstance(vote, str) and vote else "no"

    def _vote_fn(caller_label: str):
        if caller_label == sonnet_label:
            return _sonnet_vote
        return lambda gene: _db_vote(gene, caller_label)

    def _n_correct(caller_label: str, genes: list[str]) -> int:
        vote = _vote_fn(caller_label)
        return sum(_vote_correct(vote(g), truth_by_gene[g]) for g in genes)

    n_by_bucket = {b: len(_bucket_genes(truth_by_gene, b)) for b in COLUMNS}
    # Only plot buckets that contain genes: seaborn draws no bar for an
    # empty category, which would misalign the caller-major patch indexing.
    columns = [b for b in COLUMNS if n_by_bucket[b]]
    col_label = {
        **COLUMN_LABEL,
        "overall": f"overall\n(all {n_by_bucket['overall']} proteins)",
    }

    # Display order: LLM cell first, then DBs sorted by overall accuracy
    # (descending; stable for ties) — matches the canonical generator's
    # convention so the strongest source sits next to the LLM bar.
    all_genes = list(truth_by_gene)
    db_labels_sorted = sorted(DB_LABELS, key=lambda lbl: -_n_correct(lbl, all_genes))
    callers_in_plot = [sonnet_label, *db_labels_sorted]

    # Fetched once; feeds both the Sonnet bar heights and the overlay.
    rep_accs = _sonnet_replicate_accuracies(truth_by_gene)

    rows = []
    for caller_label in callers_in_plot:
        for bucket in columns:
            genes = _bucket_genes(truth_by_gene, bucket)
            n_correct = _n_correct(caller_label, genes)
            frac = n_correct / len(genes)
            # Sonnet bar = mean-of-replicates fraction (when available) so it
            # lines up with the overlaid points + SEM; n_correct stays the
            # canonical-prediction count. DB callers are deterministic.
            if caller_label == sonnet_label and bucket in rep_accs:
                frac = statistics.fmean(rep_accs[bucket])
            rows.append({
                "caller": caller_label,
                "bucket": bucket,
                "bucket_label": col_label[bucket],
                "n_correct": n_correct,
                "n_total": len(genes),
                "fraction": frac,
            })
    df = pd.DataFrame(rows)
    frac_by_cell = {(row["caller"], row["bucket"]): row["fraction"] for row in rows}

    _apply_brand_style()
    fig, ax = plt.subplots(figsize=(11, 5.5))
    sns.barplot(
        data=df,
        x="bucket_label", y="fraction",
        hue="caller",
        order=[col_label[b] for b in columns],
        hue_order=callers_in_plot,
        palette=[DB_PALETTE[c] for c in callers_in_plot],
        errorbar=None,  # one value per cell — nothing to aggregate
        edgecolor="none", saturation=1.0,
        ax=ax,
    )

    n_col = len(columns)

    def _bar(caller_idx: int, col_idx: int):
        # seaborn>=0.13 issues one ax.bar() call per hue level, so bar
        # patches come caller-major: caller 0's buckets, then caller 1's, ...
        return ax.patches[caller_idx * n_col + col_idx]

    # Insert a small visible gap between the single LLM bar and the DB
    # cluster within each column group. Mirrors the canonical generator's
    # layout (scripts/triage_bench_db_barplot.py).
    gap = _bar(0, 0).get_width() * 0.6
    for caller_idx in range(1, len(callers_in_plot)):
        for col_idx in range(n_col):
            patch = _bar(caller_idx, col_idx)
            patch.set_x(patch.get_x() + gap)

    # Per-bar percentage annotations.
    for caller_idx, caller in enumerate(callers_in_plot):
        for col_idx, bucket in enumerate(columns):
            patch = _bar(caller_idx, col_idx)
            ax.text(
                patch.get_x() + patch.get_width() / 2,
                patch.get_height() + 0.01,
                f"{frac_by_cell[(caller, bucket)]:.0%}",
                ha="center", va="bottom",
                fontsize=11, color=BRAND_INK,
            )

    # Individual-replicate accuracy + SEM on the Sonnet bars (caller index
    # 0). DB callers are deterministic — no overlay. SEM needs >= 2 reps.
    for col_idx, bucket in enumerate(columns):
        accs = rep_accs.get(bucket, [])
        if len(accs) < 2:
            continue
        patch = _bar(0, col_idx)
        xc = patch.get_x() + patch.get_width() / 2
        sem = statistics.stdev(accs) / len(accs) ** 0.5
        ax.errorbar(xc, statistics.fmean(accs), yerr=sem, fmt="none", ecolor=BRAND_INK,
                    elinewidth=1.0, capsize=2.5, capthick=1.0, zorder=5)
        for k, acc in enumerate(accs):
            jitter = (k - (len(accs) - 1) / 2) * (patch.get_width() * 0.22)
            ax.scatter(xc + jitter, acc, s=14, color=BRAND_INK,
                       edgecolor="white", linewidth=0.4, zorder=6, alpha=0.9)

    ax.set_xlabel("")
    ax.set_ylabel("Fraction correctly\nclassified")
    ax.set_ylim(0, 1.14)
    ax.yaxis.set_major_locator(plt.MaxNLocator(6))
    overall = {caller: frac_by_cell[(caller, "overall")] for caller in callers_in_plot}
    handles, _ = ax.get_legend_handles_labels()
    legend_labels = [f"{lbl} ({overall[lbl]:.0%})" for lbl in callers_in_plot]
    ax.legend(
        handles, legend_labels,
        title="Caller (overall acc.)",
        loc="upper left", bbox_to_anchor=(1.02, 1.0),
        frameon=False, borderaxespad=0.0,
    )
    # n-per-class subtitle. Pushed below the x-tick labels (which are
    # multi-line for the "overall" + "contextual" columns) so it doesn't
    # collide with their second lines.
    subtitle_parts = [f"n(overall) = {n_by_bucket['overall']}"]
    subtitle_parts += [f"n({b}) = {n_by_bucket[b]}" for b in ("yes", "contextual", "no")]
    ax.text(
        0.5, -0.34, " · ".join(subtitle_parts),
        transform=ax.transAxes, ha="center", va="top",
        fontsize=13, color=BRAND_NEUTRAL,
    )
    sns.despine(ax=ax, top=True, right=True)
    out_pdf = Path("db_correctness_by_class.pdf")
    out_png = Path("db_correctness_by_class.png")
    fig.savefig(out_pdf, bbox_inches="tight", metadata={"Subject": GIST_URL})
    fig.savefig(out_png, bbox_inches="tight", dpi=300, metadata={"Source": GIST_URL})
    plt.close(fig)
    print(f"Wrote {out_pdf} + {out_png} ({len(df)} (caller, bucket) cells; "
          f"UniProt TM+signal n={len(uniprot_opt):,}, CSPA HC-only n={len(cspa_opt):,})")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment