@wassname
Last active April 23, 2026 05:40
Guided CoT Evaluation: Hybrid Teacher-Forced + On-Policy Reasoning
"""Reusable guided-rollout primitive: think → forced-close-think → JSON choice.
One rollout, three numbers. The same primitive backs:
- calibrate()'s coherence + format + rep measurement
- probe replay at edit time
- post-keep probe regeneration
- (future) DD eval (once _measure_logratios is ported to this)
Substrate:
<user_prompt + schema_hint>
<think>
... model thinks up to max_think_tokens greedy ...
</think>
{"choice": ← model emits a schema value here
Critical: on gemma-3-4b, `</think>` is multi-token so early-stop-on-eos doesn't
fire. The model runs the full budget and often emits `</think>\\n{"choice": v}`
inline. We detect that string-wise AFTER generation and score at the natural
position (right after `{"choice": `), rather than blindly splicing a 2nd
suffix (which would score garbage).
Verbose mode dumps the raw decoded string with special tokens for format
debugging, and shows top-3 tokens at the scoring position whenever
pmass_format < 0.5 — i.e. whenever the model is not following the schema.
"""
from __future__ import annotations
import contextlib
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from loguru import logger
from ssteer.core import _input_device
from ssteer.hooks import installed_svd, steer
_CLOSE_MARKER: str = "</think>"
_PREFILL: str = '\n{"choice": '
# Used only when the model never emitted `</think>` on its own.
_FORCE_SUFFIX: str = "\nI should answer now." + _CLOSE_MARKER + _PREFILL
@dataclass
class GuidedResult:
    user_prompt: str
    think_text: str        # text before </think>
    answer_text: str       # model's continuation after `{"choice": ` (3-4 tok)
    raw_full_text: str     # whole decoded trace incl. specials (verbose debug)
    pmass_format: float    # P(all choice_token_ids) at answer position
    logratio_ab: float     # log P(a_ids) - log P(b_ids); NaN if no b_ids
    rep_ratio_think: float # 4-gram distinct over think_text; NaN if <32 words
    think_tokens: int      # count of think tokens (pre-</think>)
    emitted_close: bool    # True if model emitted </think> string itself
    emitted_prefill: bool  # True if model emitted `{"choice": ` itself
_REP_MIN_TOKENS: int = 32
def _ngram_rep_ratio(text: str, n: int = 4) -> float:
    tokens = text.split()
    if len(tokens) < _REP_MIN_TOKENS:
        return float("nan")
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    return len(set(ngrams)) / len(ngrams)
_DEFAULT_SCHEMA_HINT: str = (
    "Think briefly, then answer immediately and only with: "
    '{"choice": true} or {"choice": false}.'
)
@torch.no_grad()
def guided_rollout(
    model, tok,
    user_prompt: str,
    choice_token_ids: list,
    cvec=None,
    coeff: float = 0.0,
    max_think_tokens: int = 128,
    answer_tokens: int = 4,
    schema_hint: str = _DEFAULT_SCHEMA_HINT,
    verbose: bool = False,
) -> GuidedResult:
    """Think → forced-close → JSON choice, all under one steering context.

    `choice_token_ids` is either a flat list (pmass only, no logratio) or
    `[a_variants, b_variants]` (pmass + logratio = logsumexp(a) - logsumexp(b)).

    verbose=True prints the raw decoded trace (with special tokens) and, when
    pmass_format < 0.5, the top-3 next-token candidates at the scoring
    position — so you can see WHAT the model thought it should emit.
    """
    device = _input_device(model)
    full_user = f"{user_prompt}\n\n{schema_hint}" if schema_hint else user_prompt
    messages = [{"role": "user", "content": full_user}]
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    prompt = prompt + "<think>\n"
    enc = tok(prompt, return_tensors="pt").to(device)
    prompt_len = enc.input_ids.shape[1]

    # Multi-token </think> on gemma-3 means this eos rarely fires; we handle
    # the string-level case below.
    think_end_id = tok.convert_tokens_to_ids("</think>")
    if think_end_id in (None, getattr(tok, "unk_token_id", None)):
        think_end_id = tok.eos_token_id
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    svd_ctx = installed_svd(model, cvec) if cvec is not None else contextlib.nullcontext()
    steer_ctx = steer(model, cvec, coeff) if cvec is not None else contextlib.nullcontext()
    with svd_ctx, steer_ctx:
        # Phase 1: think, greedy.
        phase1 = model.generate(
            **enc,
            max_new_tokens=max_think_tokens,
            do_sample=False,
            eos_token_id=think_end_id,
            pad_token_id=pad_id,
        )
        gen_ids = phase1[0, prompt_len:]
        keep = gen_ids != pad_id
        gen_ids = gen_ids[keep] if keep.any() else gen_ids[:0]
        gen_text = tok.decode(gen_ids, skip_special_tokens=True)

        # String-level split. Handles both (a) model naturally emitted </think>,
        # and (b) budget exhausted without closing.
        emitted_close = _CLOSE_MARKER in gen_text
        if emitted_close:
            think_text, after = gen_text.split(_CLOSE_MARKER, 1)
            if _PREFILL.lstrip() in after:
                # Model also wrote the prefill. Align scoring prefix to its position.
                emitted_prefill = True
                before_value = after.split(_PREFILL.lstrip(), 1)[0]
                scoring_text = prompt + think_text + _CLOSE_MARKER + before_value + _PREFILL.lstrip()
            else:
                emitted_prefill = False
                scoring_text = prompt + think_text + _CLOSE_MARKER + _PREFILL
        else:
            think_text = gen_text
            emitted_prefill = False
            scoring_text = prompt + gen_text + _FORCE_SUFFIX

        # Re-tokenize the scoring prefix — next-token at position -1 is the value.
        score_ids = tok(scoring_text, return_tensors="pt",
                        add_special_tokens=False).input_ids.to(device)

        # Phase 2a: score at scoring position.
        logits = model(score_ids).logits[0, -1].float()
        logp = F.log_softmax(logits, dim=-1)
        if (len(choice_token_ids) == 2
                and all(isinstance(x, (list, tuple)) for x in choice_token_ids)):
            a_ids, b_ids = list(choice_token_ids[0]), list(choice_token_ids[1])
        else:
            a_ids, b_ids = list(choice_token_ids), []
        all_ids = torch.tensor(a_ids + b_ids, device=device, dtype=torch.long)
        pmass_format = float(logp[all_ids].exp().sum().item())
        if a_ids and b_ids:
            a_t = torch.tensor(a_ids, device=device, dtype=torch.long)
            b_t = torch.tensor(b_ids, device=device, dtype=torch.long)
            logratio = float(torch.logsumexp(logp[a_t], dim=0).item()
                             - torch.logsumexp(logp[b_t], dim=0).item())
        else:
            logratio = float("nan")

        # Phase 2b: continue a few tokens for a readable answer_text.
        cont = model.generate(
            score_ids,
            max_new_tokens=answer_tokens,
            do_sample=False,
            pad_token_id=pad_id,
        )
        answer_ids = cont[0, score_ids.shape[1]:]
        answer_text = tok.decode(answer_ids, skip_special_tokens=True)
        raw_full_text = tok.decode(cont[0], skip_special_tokens=False)
    if verbose:
        logger.info(
            f"[guided_rollout verbose]\n"
            f"  scoring_text[-120:]: {scoring_text[-120:]!r}\n"
            f"  emitted_close={emitted_close} emitted_prefill={emitted_prefill}\n"
            f"  pmass_format={pmass_format:.3f} logratio={logratio:+.3f}\n"
            f"  answer_text={answer_text!r}\n"
            f"  === RAW (incl. specials) ===\n{raw_full_text}\n"
            f"  === END RAW ==="
        )
        if pmass_format < 0.5:
            top_logp, top_idx = torch.topk(logp, 3)
            tops = [(tok.decode([int(i)]), float(p.exp().item()))
                    for p, i in zip(top_logp, top_idx)]
            logger.warning(
                f"pmass<0.5 at scoring position. Top-3 tokens: {tops} "
                f"→ schema broken. Adjust schema_hint or add variants "
                f"to choice_token_ids."
            )
    return GuidedResult(
        user_prompt=user_prompt,
        think_text=think_text,
        answer_text=answer_text,
        raw_full_text=raw_full_text,
        pmass_format=pmass_format,
        logratio_ab=logratio,
        rep_ratio_think=_ngram_rep_ratio(think_text, n=4),
        # Count only the pre-</think> tokens; score_ids also includes the
        # close marker and prefill, so score_ids.shape[1] - prompt_len overcounts.
        think_tokens=len(tok(think_text, add_special_tokens=False).input_ids),
        emitted_close=emitted_close,
        emitted_prefill=emitted_prefill,
    )
def choice_token_ids_tf(tok) -> list[list[int]]:
    """[[true_variants...], [false_variants...]] — last token of each surface
    form, i.e. the value token expected right after `{"choice": `."""
    def _variants(words):
        seen = []
        for s in words:
            tid = tok.encode(s, add_special_tokens=False)[-1]
            if tid not in seen:
                seen.append(tid)
        return seen
    return [_variants(["true", " true", "\ntrue", "True", " True", "\nTrue"]),
            _variants(["false", " false", "\nfalse", "False", " False", "\nFalse"])]

Guided CoT eval

A trick for getting better eval signal from thinking models, with a fixed token budget.

The problem

Standard eval for concept ablation is teacher-forced: you feed the model a prefix like "My choice: **" and read the logprobs for Yes vs No. That's fast, but you only measure the effect on one token. The model never gets to reason under ablation, so you miss whether ablation actually changes the chain of thought.

Full on-policy generation (let the model write freely, parse the answer) captures everything but is slow, and parsing Yes/No from free text is fragile.
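For comparison, the teacher-forced read itself is just a log-softmax and two lookups. A minimal sketch with toy logits (values and token ids are made up, standing in for a real model's output at the scoring position):

```python
import torch
import torch.nn.functional as F

# Toy next-token logits standing in for model(ids).logits[0, -1]
# at the position right after "My choice: **".
logits = torch.tensor([2.0, 0.5, -1.0, -3.0])  # toy 4-token vocab
yes_id, no_id = 0, 1                           # hypothetical token ids

logp = F.log_softmax(logits, dim=-1)
logratio = (logp[yes_id] - logp[no_id]).item()  # positive means "Yes" wins
```

Note that the difference of log-softmax values equals the difference of raw logits, so here the logratio is exactly 2.0 - 0.5 = 1.5.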

The trick

Let the model think for a bit, then force it to answer. Three steps:

  1. Generate a short reasoning trace (32 tokens) under ablation, greedy decoding
  2. Append a fixed suffix: \nI should answer now.\n</think>\nMy choice: **
  3. One forward pass on the whole sequence, read logprobs at the final position

The output is a logratio: log P(Yes) - log P(No), summed over tokenizer variants of "Yes"/"No" via logsumexp. Because this is a continuous logratio rather than a hard label, you can also compute calibrated uncertainties.
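To turn the logratio into a calibrated probability over the two options, a sigmoid suffices, since log P(Yes) - log P(No) is exactly the log-odds of Yes restricted to {Yes, No}. A small sketch:

```python
import math

def logratio_to_p_yes(logratio: float) -> float:
    """P(Yes | answer in {Yes, No}) from log P(Yes) - log P(No)."""
    return 1.0 / (1.0 + math.exp(-logratio))

logratio_to_p_yes(0.0)   # 0.5: model is indifferent
logratio_to_p_yes(2.0)   # ~0.88: leans Yes
```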

Here's what the token sequence looks like:

<|im_start|>assistant
<think>                              <-- chat template adds this
Thinking Process:
1. Analyze the Request:
   * Role: Main person...            <-- 32 tokens of on-policy reasoning
I should answer now.                  <-- appended
</think>
My choice: **                         <-- appended, score here
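The sequence above is assembled by plain string concatenation before the final forward pass. A sketch (the template prefix and trace text are illustrative, not real model output):

```python
# Illustrative chat-template prefix; in practice it comes from
# tok.apply_chat_template(..., add_generation_prompt=True).
prompt = "<|im_start|>assistant\n<think>\n"
think_trace = "Thinking Process:\n1. Analyze the Request: ..."  # on-policy tokens
suffix = "\nI should answer now.\n</think>\nMy choice: **"      # fixed, appended

scoring_text = prompt + think_trace + suffix
# The next-token distribution at the end of scoring_text is where Yes/No is read.
```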

Pseudocode

def guided_eval(model, prompt, n_think=32):
    # prompt ends at "<think>\n" (from chat template)

    # ── 1. On-policy thinking (ablated) ──
    ids = model.generate(prompt, max_new=n_think, greedy=True)

    # ── 2. Force transition to answer ──
    suffix = tokenize("\nI should answer now.\n</think>\nMy choice: **")
    ids = cat([ids, suffix])

    # ── 3. Score final position ──
    logits = model(ids).logits[:, -1, :]   # ℓ ∈ ℝ^V
    p = log_softmax(logits)

    p_yes = logsumexp(p[yes_ids])
    p_no = logsumexp(p[no_ids])

    # pmass = exp(p_yes) + exp(p_no) should be > 0.5
    # if not, the model isn't predicting Yes/No

    return p_yes - p_no                 # logratio

When to use it

Teacher-forced is fine if you just want "does ablation flip the answer?" Guided CoT is better when you care about how ablation changes the reasoning path, because 32 tokens is enough for the chain of thought to diverge before scoring.

In practice, guided logratios correlate with teacher-forced (same sign, similar magnitude) but with more variance. That variance is from the reasoning trace, and it's informative.

I've used this across several projects and it gives better uncertainty estimates than teacher-forced, since you get proper logratios from a model that actually reasoned about the question.

Cost: 32 think tokens + 13 suffix tokens + 1 forward pass per item. For a 1360-item sweep at 8 prompts each, that's ~10K short generations instead of ~10K long ones for full on-policy.

Troubleshooting

pmass < 0.5

pmass is exp(p_yes) + exp(p_no), the total probability on Yes/No. If it's below 0.5, the model isn't confidently choosing either option after "My choice: **".
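The pmass computation is easy to reproduce in isolation. A toy sketch with made-up logits and token ids (in the real eval, `logp` comes from the scoring position):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([3.0, 2.5, -2.0, -2.0, -2.0])  # toy vocab
yes_ids, no_ids = [0], [1]                           # hypothetical variant ids

logp = F.log_softmax(logits, dim=-1)
pmass = logp[yes_ids + no_ids].exp().sum().item()
if pmass < 0.5:
    print("model is not predicting Yes/No here")
```

Here nearly all the mass sits on the two choice tokens, so pmass is close to 1; when the model is emitting something else entirely, it drops well below 0.5.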

Things to check:

  • The </think> token must be the special token (ID 248069 for Qwen3.5), not the raw string. Tokenizers handle this differently.
  • If think_tokens is too low, the model hasn't finished a coherent thought and the forced suffix confuses it. Try bumping to 64.
  • The model needs to support <think> blocks in its chat template.

Incoherent thinking trace

  • Make sure the prompt ends at the chat template's generation point. For Qwen3.5 that's <|im_start|>assistant\n<think>\n. Don't add your own <think> tag on top.
  • Large ablation coefficients (|c| > 2) can make generation incoherent. Look at the thinking trace first when debugging.

OOM

model.generate() allocates KV cache, which uses more memory than a single forward pass. I use bs=4 for guided mode vs bs=16 for teacher-forced.

Logratios identical to teacher-forced

The trace is too short for reasoning to diverge. At 4 tokens there's barely any thinking, so the score converges to teacher-forced. Start at 32, try 64-128 for more divergence.

Design notes

Greedy decoding (do_sample=False) for the thinking trace so measurements are deterministic across runs.

The "I should answer now" suffix gives the model a natural transition into answering. Without it, a bare </think> appears abruptly and the model doesn't handle the context switch as cleanly.

We don't wait for the model to generate its own </think> because it might never produce one, and a fixed token budget keeps runs comparable.

logsumexp over multiple Yes/No token IDs because tokenizers can encode "Yes" as "Yes", "yes", " Yes", etc. Summing captures the full decision mass.
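As a toy example of why the variant sum matters: if the tokenizer splits "Yes" mass across two surface forms, logsumexp over both recovers the full decision mass (probabilities and ids below are invented):

```python
import torch

logp = torch.log(torch.tensor([0.30, 0.25, 0.35, 0.10]))  # toy next-token probs
yes_ids = [0, 1]   # e.g. "Yes" and " Yes" (hypothetical ids)
no_ids = [2, 3]    # e.g. "No" and " No"

log_p_yes = torch.logsumexp(logp[yes_ids], dim=0)  # log(0.30 + 0.25) = log(0.55)
log_p_no = torch.logsumexp(logp[no_ids], dim=0)    # log(0.35 + 0.10) = log(0.45)
logratio = (log_p_yes - log_p_no).item()           # log(0.55/0.45), about +0.20
```

Scoring only token 0 against token 2 would flip the sign (0.30 vs 0.35); summing the variants gives the correct Yes-leaning answer.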


wassname commented Apr 5, 2026

I'll note

  • this works well and leads to much faster research
  • you get much less variance because you read log probs directly, so you can eval with fewer rollouts
  • this seems to correlate well with a full eval, so I treat it as a quick proxy for dev, and also for more informative error bars.
