Skip to content

Instantly share code, notes, and snippets.

@Mistobaan
Created March 21, 2026 01:07
Show Gist options
  • Select an option

  • Save Mistobaan/f3fe6538f34796d6fa5f501c08f97a49 to your computer and use it in GitHub Desktop.

Select an option

Save Mistobaan/f3fe6538f34796d6fa5f501c08f97a49 to your computer and use it in GitHub Desktop.
Vibe Coded Script to check the length of tokenized documents

OpenAI Parameter Golf

This is a vibe-coded script, written by Codex, that reports the document-length distribution (mean, percentiles, and maximum) after tokenization for the pre-tokenized Parameter Golf dataset.

Full train-split summary from the current cached dataset

documents: 6,292,940
tokens_accounted_for: 8,000,000,000
mean_len: 1271.266
p50: 747
p90: 2478
p95: 3646
p99: 8952
p99.9: 28812
max_len: 316748

Results for the val split

documents: 50,000
tokens_accounted_for: 62,021,846
mean_len: 1240.437
p50: 733
p90: 2454
p95: 3592
p99: 8602
p99.9: 26560
max_len: 123565
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import glob
import math
import os
import sys
from dataclasses import dataclass
from pathlib import Path
import numpy as np
# Number of little-endian int32 values read as the header at the start of every shard file.
HEADER_INTS = 256
# Value header[0] must equal; load_data_shard rejects a shard whose magic differs.
DATAFILE_MAGIC = 20240520
# Value header[1] must equal; load_data_shard rejects a shard whose version differs.
DATAFILE_VERSION = 1
# Default for --bos-id: the token id treated as the start of each document.
DEFAULT_BOS_ID = 1
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Shard selection options are layered: ``--files`` beats ``--pattern``,
    which beats ``--split`` (resolved under ``--data-path``). The default
    data path comes from the ``DATA_PATH`` environment variable when set.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    data_path_default = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")

    # One (flag, keyword-arguments) entry per option, in the order they
    # should appear in ``--help``.
    option_table = [
        (
            "--data-path",
            {
                "default": data_path_default,
                "help": "Dataset directory. Default: DATA_PATH or ./data/datasets/fineweb10B_sp1024",
            },
        ),
        (
            "--split",
            {
                "choices": ("train", "val", "test"),
                "default": "train",
                "help": (
                    "Named shard set under --data-path. "
                    "'test' is an alias for the held-out fineweb_val_* shards in this repo."
                ),
            },
        ),
        (
            "--pattern",
            {
                "default": None,
                "help": "Optional shard glob override. Takes precedence over --split.",
            },
        ),
        (
            "--files",
            {
                "nargs": "+",
                "default": None,
                "help": (
                    "Explicit file paths and/or glob patterns to scan. "
                    "Takes precedence over --pattern and --split."
                ),
            },
        ),
        (
            "--bos-id",
            {
                "type": int,
                "default": DEFAULT_BOS_ID,
                "help": "Token id that marks the start of each document. Default: 1",
            },
        ),
        (
            "--exclude-bos",
            {
                "action": "store_true",
                "help": "Report content length excluding the leading BOS token from each document.",
            },
        ),
        (
            "--limit-shards",
            {
                "type": int,
                "default": 0,
                "help": "Only scan the first N matched shards. Useful for quick smoke tests.",
            },
        ),
        (
            "--progress-every",
            {
                "type": int,
                "default": 10,
                "help": "Print progress every N shards to stderr. Use 0 to disable.",
            },
        ),
        (
            "--write-tsv",
            {
                "type": Path,
                "default": None,
                "help": "Optional path to write the exact per-length counts as TSV.",
            },
        ),
    ]

    cli = argparse.ArgumentParser(
        description=(
            "Compute the document-length distribution for FineWeb shard sets. "
            "This matches train_gpt.py's shard format and BOS-delimited document boundaries."
        )
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def load_data_shard(file: Path) -> np.ndarray:
    """Read the token array stored in one shard file.

    The file starts with ``HEADER_INTS`` little-endian int32 values, where
    entry 0 is the magic number, entry 1 the format version and entry 2 the
    token count. The tokens follow as little-endian uint16 values.

    Args:
        file: Path of the shard to read.

    Returns:
        A 1-D ``<u2`` array holding exactly the number of tokens announced
        in the header.

    Raises:
        ValueError: If the header is short or carries an unexpected magic or
            version, if the file size does not match the header's token
            count, or if fewer tokens than announced could be read.
    """
    header_dtype = np.dtype("<i4")
    token_dtype = np.dtype("<u2")
    header_nbytes = HEADER_INTS * header_dtype.itemsize

    header = np.fromfile(file, dtype=header_dtype, count=HEADER_INTS)
    header_ok = (
        header.size == HEADER_INTS
        and int(header[0]) == DATAFILE_MAGIC
        and int(header[1]) == DATAFILE_VERSION
    )
    if not header_ok:
        raise ValueError(f"Unexpected shard header for {file}")

    num_tokens = int(header[2])
    # The size check runs before the bulk read, so a truncated or padded
    # shard is rejected without loading its payload.
    expected_size = header_nbytes + num_tokens * token_dtype.itemsize
    if file.stat().st_size != expected_size:
        raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes")

    tokens = np.fromfile(file, dtype=token_dtype, count=num_tokens, offset=header_nbytes)
    if tokens.size != num_tokens:
        raise ValueError(f"Short read for {file}")
    return tokens
@dataclass
class LengthStats:
    """Running histogram and summary statistics over document lengths.

    ``hist[k]`` counts the documents of length ``k``; the array is widened on
    demand as longer documents are added.
    """

    hist: np.ndarray
    total_docs: int = 0
    total_tokens: int = 0
    min_len: int | None = None
    max_len: int = 0

    def add(self, lengths: np.ndarray) -> None:
        """Fold a batch of document lengths into the statistics.

        Args:
            lengths: Array-like of non-negative lengths. An empty batch is a
                no-op.

        Raises:
            ValueError: If any length in the batch is negative.
        """
        batch = np.asarray(lengths, dtype=np.int64)
        if not batch.size:
            return
        if (batch < 0).any():
            raise ValueError("Encountered negative document length after adjustment")

        batch_counts = np.bincount(batch)
        if batch_counts.size > self.hist.size:
            # Widen the histogram so every length in this batch has a slot.
            widened = np.zeros((batch_counts.size,), dtype=np.int64)
            widened[: self.hist.size] = self.hist
            self.hist = widened
        self.hist[: batch_counts.size] += batch_counts

        self.total_docs += int(batch.size)
        self.total_tokens += int(batch.sum(dtype=np.int64))

        smallest, largest = int(batch.min()), int(batch.max())
        if self.min_len is None or smallest < self.min_len:
            self.min_len = smallest
        if largest > self.max_len:
            self.max_len = largest
@dataclass
class ScanResult:
    """Everything scan_doc_lengths reports back about one pass over the shards."""

    # Histogram and summary statistics of the document lengths that were counted.
    stats: LengthStats
    # Number of shard files actually read during the scan.
    shards_scanned: int
    # Number of shard files matched by the selection, passed through from the caller.
    total_shards_matched: int
    # Sum of the token counts of every shard that was read.
    total_tokens_seen: int
    # Length of the trailing open document left out of `stats`; 0 when nothing was dropped.
    dropped_tail_tokens: int
    # Human-readable description of how the shard files were selected.
    file_selection: str
def maybe_adjust(lengths: np.ndarray, exclude_bos: bool) -> np.ndarray:
    """Optionally discount the leading BOS token from each document length.

    Args:
        lengths: Document lengths.
        exclude_bos: When true, return a new array with every length reduced
            by one; otherwise return ``lengths`` itself, unchanged.

    Returns:
        The adjusted (or original) lengths.
    """
    return lengths - 1 if exclude_bos else lengths
def scan_doc_lengths(
    files: list[Path],
    bos_id: int,
    exclude_bos: bool,
    progress_every: int,
    total_shards_matched: int,
    file_selection: str,
    drop_final_partial_doc: bool,
) -> ScanResult:
    """Measure every BOS-delimited document across an ordered list of shards.

    A document starts at a token equal to ``bos_id`` and runs until the next
    such token. Documents may straddle shard boundaries, so the length of the
    document that is still open at the end of a shard is carried into the
    next one.

    Args:
        files: Shards to read, in order.
        bos_id: Token id that marks the start of a document.
        exclude_bos: Subtract one from each recorded length (drops the BOS).
        progress_every: Emit a progress line to stderr every N shards and
            after the last shard; 0 disables progress output.
        total_shards_matched: Stored in the result unchanged.
        file_selection: Stored in the result unchanged.
        drop_final_partial_doc: When true, the document still open after the
            last shard is not recorded; its length is reported as
            ``dropped_tail_tokens`` instead.

    Returns:
        A ``ScanResult`` with the accumulated statistics and scan totals.

    Raises:
        ValueError: If tokens appear before any BOS token has been seen,
            i.e. the first shard does not begin at a document boundary.
    """
    stats = LengthStats(hist=np.zeros((1,), dtype=np.int64))
    open_doc_len: int | None = None  # None until the first BOS has been seen
    tokens_seen = 0
    shard_count = len(files)

    def record(doc_lengths) -> None:
        # Every recorded length goes through the same BOS adjustment.
        stats.add(maybe_adjust(np.asarray(doc_lengths, dtype=np.int64), exclude_bos))

    for shard_no, shard_path in enumerate(files, start=1):
        tokens = load_data_shard(shard_path)
        tokens_seen += int(tokens.size)
        doc_starts = np.flatnonzero(tokens == bos_id)

        if doc_starts.size:
            leading = int(doc_starts[0])
            if open_doc_len is not None:
                # The first BOS of this shard closes the carried-over document.
                record([open_doc_len + leading])
            elif leading != 0:
                raise ValueError(
                    f"{shard_path} begins with {leading} tokens before the first BOS token. "
                    "Expected the matched shard set to begin at a document boundary."
                )
            if doc_starts.size > 1:
                # Gaps between consecutive BOS positions are complete documents.
                record(np.diff(doc_starts))
            # Whatever follows the last BOS stays open into the next shard.
            open_doc_len = int(tokens.size - int(doc_starts[-1]))
        elif open_doc_len is None:
            raise ValueError(
                f"No BOS token found before or inside {shard_path}. "
                "Expected the matched shard set to begin at a document boundary."
            )
        else:
            # No boundary in this shard: it belongs entirely to the open document.
            open_doc_len += int(tokens.size)

        if progress_every and (shard_no % progress_every == 0 or shard_no == shard_count):
            pending = open_doc_len if open_doc_len is not None else 0
            print(
                f"[progress] shards={shard_no}/{shard_count} docs={stats.total_docs:,} "
                f"tokens_seen={tokens_seen:,} open_doc_tokens={pending:,}",
                file=sys.stderr,
                flush=True,
            )

    dropped = 0
    if open_doc_len is not None:
        if drop_final_partial_doc:
            dropped = open_doc_len
        else:
            record([open_doc_len])

    return ScanResult(
        stats=stats,
        shards_scanned=shard_count,
        total_shards_matched=total_shards_matched,
        total_tokens_seen=tokens_seen,
        dropped_tail_tokens=dropped,
        file_selection=file_selection,
    )
def nearest_rank_quantile(hist: np.ndarray, total_docs: int, q: float) -> int:
    """Return the nearest-rank quantile of a length histogram.

    The result is the smallest length ``L`` such that at least
    ``max(1, ceil(q * total_docs))`` documents have length ``<= L``.

    Args:
        hist: Histogram where ``hist[k]`` is the number of documents of
            length ``k``.
        total_docs: Number of documents counted in ``hist``.
        q: Quantile as a fraction in ``[0, 1]`` (0.5 is the median).

    Returns:
        The document length at the requested quantile.

    Raises:
        ValueError: If ``total_docs`` is not positive, if ``q`` lies outside
            ``[0, 1]`` (NaN included), or if the requested rank exceeds the
            number of documents recorded in ``hist``.
    """
    if total_docs <= 0:
        raise ValueError("Quantiles require at least one document")
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0.0 <= q <= 1.0:
        raise ValueError(f"Quantile fraction must be within [0, 1], got {q!r}")
    rank = max(1, int(math.ceil(q * total_docs)))
    cdf = np.cumsum(hist, dtype=np.int64)
    # Without this guard searchsorted would return len(hist), a length that
    # no document has, whenever the rank is past the histogram's total mass.
    if cdf.size == 0 or rank > int(cdf[-1]):
        raise ValueError("Requested rank exceeds the number of documents recorded in hist")
    return int(np.searchsorted(cdf, rank, side="left"))
def print_summary(result: ScanResult, exclude_bos: bool) -> None:
    """Print the scan summary and a power-of-two length histogram to stdout.

    Args:
        result: Scan outcome to report.
        exclude_bos: Only selects the wording of the ``length_mode`` line;
            the statistics are printed as stored in ``result``.

    Raises:
        ValueError: If ``result`` contains no documents.
    """
    stats = result.stats
    doc_total = stats.total_docs
    if doc_total == 0:
        raise ValueError("No complete documents were found")

    # All quantiles are computed before anything is printed.
    quantile_points = (
        ("p50", 0.50),
        ("p90", 0.90),
        ("p95", 0.95),
        ("p99", 0.99),
        ("p99.9", 0.999),
    )
    quantiles = {
        label: nearest_rank_quantile(stats.hist, doc_total, fraction)
        for label, fraction in quantile_points
    }

    if exclude_bos:
        length_mode = "content tokens per document (leading BOS excluded)"
    else:
        length_mode = "stored tokens per document (leading BOS included)"

    report = [
        f"file_selection: {result.file_selection}",
        f"length_mode: {length_mode}",
        f"shards_scanned: {result.shards_scanned}/{result.total_shards_matched}",
        f"documents: {doc_total:,}",
        f"tokens_accounted_for: {stats.total_tokens:,}",
        f"tokens_seen_in_shards: {result.total_tokens_seen:,}",
    ]
    if result.dropped_tail_tokens:
        report.append(f"dropped_tail_tokens: {result.dropped_tail_tokens:,}")
    report.append(f"min_len: {stats.min_len}")
    report.append(f"mean_len: {stats.total_tokens / doc_total:.3f}")
    report.extend(f"{label}_len: {value}" for label, value in quantiles.items())
    report.append(f"max_len: {stats.max_len}")
    report.append("")
    report.append("log2_histogram:")
    for line in report:
        print(line)

    def show_bucket(lo: int, hi: int, count: int) -> None:
        share = 100.0 * count / doc_total
        print(f" [{lo}, {hi}]: count={count:,} frac={share:.3f}%")

    # Length 0 has no power-of-two bucket, so it gets a row of its own.
    zero_docs = int(stats.hist[0])
    if zero_docs:
        show_bucket(0, 0, zero_docs)

    # Start at the power of two at or below the shortest document.
    bucket_lo = 1
    if stats.min_len is not None and stats.min_len > 1:
        bucket_lo = 1 << int(math.floor(math.log2(stats.min_len)))
    while bucket_lo <= stats.max_len:
        bucket_hi = min(stats.max_len, 2 * bucket_lo - 1)
        in_bucket = int(stats.hist[bucket_lo : bucket_hi + 1].sum(dtype=np.int64))
        if in_bucket:
            show_bucket(bucket_lo, bucket_hi, in_bucket)
        bucket_lo *= 2
def write_tsv(path: Path, hist: np.ndarray, total_docs: int) -> None:
    """Write the non-empty histogram bins to a tab-separated file.

    After a header line, one row is written per length that occurs at least
    once: the length, its count, its fraction of ``total_docs`` and the
    running cumulative fraction. Missing parent directories are created.

    Args:
        path: Destination file; overwritten if it exists.
        hist: Histogram where ``hist[k]`` counts documents of length ``k``.
        total_docs: Denominator used for both fraction columns.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    occupied = np.flatnonzero(hist)

    def rows():
        running = 0
        for length in occupied:
            count = int(hist[length])
            running += count
            yield (
                f"{int(length)}\t{count}\t{count / total_docs:.12f}\t{running / total_docs:.12f}\n"
            )

    with path.open("w", encoding="utf-8") as out:
        out.write("doc_len_tokens\tcount\tfraction\tcumulative_fraction\n")
        out.writelines(rows())
def expand_file_specs(specs: list[str]) -> list[Path]:
    """Expand paths and glob patterns into a de-duplicated list of files.

    Each spec is globbed and its matches sorted; specs are handled in the
    order given. A path is kept only the first time its normalized form is
    seen. A spec that matches nothing contributes no entries.

    Args:
        specs: File paths and/or glob patterns.

    Returns:
        The normalized matching paths in first-seen order.
    """
    # Dicts keep insertion order, so the keys double as the "seen" set.
    unique: dict[str, Path] = {}
    for spec in specs:
        for hit in sorted(glob.glob(spec)):
            key = os.path.normpath(hit)
            if key not in unique:
                unique[key] = Path(key)
    return list(unique.values())
def resolve_file_selection(args: argparse.Namespace) -> tuple[list[Path], str]:
    """Pick the shard files to scan from the parsed command-line options.

    Precedence: ``args.files``, then ``args.pattern``, then the named
    ``args.split`` resolved under ``args.data_path``.

    Args:
        args: Parsed options; reads ``files``, ``pattern``, ``split`` and
            ``data_path``.

    Returns:
        The matching files and a short description of how they were chosen.
    """
    if args.files:
        return expand_file_specs(args.files), f"files={', '.join(args.files)}"

    if args.pattern:
        return expand_file_specs([args.pattern]), f"pattern={args.pattern}"

    # "test" maps to the same shards as "val".
    glob_by_split = {
        "train": "fineweb_train_*.bin",
        "val": "fineweb_val_*.bin",
        "test": "fineweb_val_*.bin",
    }
    split_glob = glob_by_split[args.split]
    pattern = str(Path(args.data_path) / split_glob)
    return expand_file_specs([pattern]), f"split={args.split} pattern={pattern}"
def main() -> None:
    """Run the tool: parse options, scan the shards, print the summary.

    When ``--limit-shards`` cuts the matched shard list short, the document
    still open at the end of the last scanned shard is dropped rather than
    counted. The per-length TSV is written only when ``--write-tsv`` is given.

    Raises:
        FileNotFoundError: If the selection matches no files.
    """
    args = parse_args()
    matched, file_selection = resolve_file_selection(args)
    if not matched:
        raise FileNotFoundError(f"No files found for selection: {file_selection}")

    limit = args.limit_shards
    selected = matched[:limit] if limit > 0 else matched
    truncated = limit > 0 and len(selected) < len(matched)

    result = scan_doc_lengths(
        files=selected,
        bos_id=args.bos_id,
        exclude_bos=args.exclude_bos,
        progress_every=args.progress_every,
        total_shards_matched=len(matched),
        file_selection=file_selection,
        drop_final_partial_doc=truncated,
    )
    print_summary(result, exclude_bos=args.exclude_bos)

    tsv_path = args.write_tsv
    if tsv_path is None:
        return
    write_tsv(tsv_path, result.stats.hist, result.stats.total_docs)
    print()
    print(f"wrote_tsv: {tsv_path}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment