|
#!/usr/bin/env python3 |
|
from __future__ import annotations |
|
|
|
import argparse |
|
import glob |
|
import math |
|
import os |
|
import sys |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
|
|
# Number of little-endian int32 values read as the header of every shard file.
HEADER_INTS = 256
# Value header[0] must hold; shards with any other value are rejected.
DATAFILE_MAGIC = 20240520
# Value header[1] must hold (shard format version).
DATAFILE_VERSION = 1
# Default for --bos-id: the token id treated as the start of a document.
DEFAULT_BOS_ID = 1
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the document-length scan.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing an explicit list lets callers and tests drive the parser
            without modifying process state.

    Returns:
        Namespace with the attributes ``data_path``, ``split``, ``pattern``,
        ``files``, ``bos_id``, ``exclude_bos``, ``limit_shards``,
        ``progress_every`` and ``write_tsv``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Compute the document-length distribution for FineWeb shard sets. "
            "This matches train_gpt.py's shard format and BOS-delimited document boundaries."
        )
    )

    # File selection: --files beats --pattern, which beats --split.
    parser.add_argument(
        "--data-path",
        default=os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024"),
        help="Dataset directory. Default: DATA_PATH or ./data/datasets/fineweb10B_sp1024",
    )
    parser.add_argument(
        "--split",
        choices=("train", "val", "test"),
        default="train",
        help=(
            "Named shard set under --data-path. "
            "'test' is an alias for the held-out fineweb_val_* shards in this repo."
        ),
    )
    parser.add_argument(
        "--pattern",
        default=None,
        help="Optional shard glob override. Takes precedence over --split.",
    )
    parser.add_argument(
        "--files",
        nargs="+",
        default=None,
        help=(
            "Explicit file paths and/or glob patterns to scan. "
            "Takes precedence over --pattern and --split."
        ),
    )

    # How document lengths are measured.
    parser.add_argument(
        "--bos-id",
        type=int,
        default=DEFAULT_BOS_ID,
        # Derived from the constant so the help text cannot drift from the default.
        help=f"Token id that marks the start of each document. Default: {DEFAULT_BOS_ID}",
    )
    parser.add_argument(
        "--exclude-bos",
        action="store_true",
        help="Report content length excluding the leading BOS token from each document.",
    )

    # Run control and output.
    parser.add_argument(
        "--limit-shards",
        type=int,
        default=0,
        help="Only scan the first N matched shards. Useful for quick smoke tests.",
    )
    parser.add_argument(
        "--progress-every",
        type=int,
        default=10,
        help="Print progress every N shards to stderr. Use 0 to disable.",
    )
    parser.add_argument(
        "--write-tsv",
        type=Path,
        default=None,
        help="Optional path to write the exact per-length counts as TSV.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
def load_data_shard(file: Path) -> np.ndarray:
    """Read one token shard from disk and return its tokens.

    The file is read as ``HEADER_INTS`` little-endian int32 header values
    followed by little-endian uint16 tokens. ``header[0]`` must equal
    ``DATAFILE_MAGIC``, ``header[1]`` must equal ``DATAFILE_VERSION`` and
    ``header[2]`` is the token count, which must account for the file size
    exactly.

    Args:
        file: Path of the shard file. ``str`` paths are accepted as well.

    Returns:
        1-D ``<u2`` array holding the shard's tokens.

    Raises:
        ValueError: If the header is truncated or carries the wrong magic or
            version, if the file size disagrees with the declared token
            count, or if fewer tokens than declared could be read.
    """
    # np.fromfile takes str paths too, but the stat() call below needs a Path.
    file = Path(file)
    header_dtype = np.dtype("<i4")
    token_dtype = np.dtype("<u2")
    header_bytes = HEADER_INTS * header_dtype.itemsize

    header = np.fromfile(file, dtype=header_dtype, count=HEADER_INTS)
    # Check the length first: the fields below cannot be indexed on a short read.
    if header.size != HEADER_INTS:
        raise ValueError(
            f"Unexpected shard header for {file}: "
            f"read {header.size} of {HEADER_INTS} header values"
        )
    magic = int(header[0])
    version = int(header[1])
    if magic != DATAFILE_MAGIC:
        raise ValueError(
            f"Unexpected shard header for {file}: magic {magic}, expected {DATAFILE_MAGIC}"
        )
    if version != DATAFILE_VERSION:
        raise ValueError(
            f"Unexpected shard header for {file}: version {version}, expected {DATAFILE_VERSION}"
        )

    num_tokens = int(header[2])
    expected_size = header_bytes + num_tokens * token_dtype.itemsize
    actual_size = file.stat().st_size
    # A negative or wrong token count can never match the real size, so this
    # check also rejects corrupt counts before any token is read.
    if actual_size != expected_size:
        raise ValueError(
            f"Shard size mismatch for {file}: expected {expected_size} bytes, "
            f"found {actual_size} bytes"
        )

    tokens = np.fromfile(file, dtype=token_dtype, count=num_tokens, offset=header_bytes)
    if tokens.size != num_tokens:
        raise ValueError(f"Short read for {file}")
    return tokens
|
|
|
|
|
@dataclass
class LengthStats:
    """Running histogram and summary statistics of document lengths."""

    # hist[n] is the number of documents of length n recorded so far.
    hist: np.ndarray
    total_docs: int = 0
    total_tokens: int = 0
    # Stays None until the first non-empty batch has been added.
    min_len: int | None = None
    max_len: int = 0

    def add(self, lengths: np.ndarray) -> None:
        """Fold a batch of document lengths into the running statistics.

        An empty batch is a no-op. The histogram is widened when the batch
        contains a length beyond its current end.

        Args:
            lengths: Document lengths; converted to an int64 array.

        Raises:
            ValueError: If any length is negative. Nothing is recorded then.
        """
        batch = np.asarray(lengths, dtype=np.int64)
        if not batch.size:
            return

        smallest = int(batch.min())
        largest = int(batch.max())
        if smallest < 0:
            raise ValueError("Encountered negative document length after adjustment")

        per_length = np.bincount(batch)
        if self.hist.size < per_length.size:
            # Widen to the new largest length, carrying the old counts over.
            widened = np.zeros((per_length.size,), dtype=np.int64)
            widened[: self.hist.size] = self.hist
            self.hist = widened
        self.hist[: per_length.size] += per_length

        self.total_docs += int(batch.size)
        self.total_tokens += int(batch.sum(dtype=np.int64))
        if self.min_len is None:
            self.min_len = smallest
        else:
            self.min_len = min(self.min_len, smallest)
        self.max_len = max(self.max_len, largest)
|
|
|
|
|
@dataclass
class ScanResult:
    """Outcome of a shard scan: the length statistics plus bookkeeping values."""

    # Accumulated per-document length statistics.
    stats: LengthStats
    # Number of shard files that were actually read.
    shards_scanned: int
    # Number of shard files the selection matched before any limit was applied.
    total_shards_matched: int
    # Sum of the token counts of all scanned shards.
    total_tokens_seen: int
    # Length of the trailing open document that was left out of `stats`;
    # 0 when nothing was dropped.
    dropped_tail_tokens: int
    # Human-readable description of how the shard files were selected.
    file_selection: str
|
|
|
|
|
def maybe_adjust(lengths: np.ndarray, exclude_bos: bool) -> np.ndarray:
    """Return the lengths to record, optionally without the leading BOS token.

    Args:
        lengths: Document lengths that include the leading BOS token.
        exclude_bos: When true, one token is subtracted from every length.

    Returns:
        ``lengths - 1`` when ``exclude_bos`` is set, otherwise ``lengths``
        itself (the same object, not a copy).
    """
    return lengths - 1 if exclude_bos else lengths
|
|
|
|
|
def scan_doc_lengths(
    files: list[Path],
    bos_id: int,
    exclude_bos: bool,
    progress_every: int,
    total_shards_matched: int,
    file_selection: str,
    drop_final_partial_doc: bool,
) -> ScanResult:
    """Scan shards in order and collect the length of every document.

    A document runs from one BOS token up to (not including) the next one.
    A document that is still open when a shard ends is carried into the
    following shard.

    Args:
        files: Shard files to read, in scan order.
        bos_id: Token id that starts a document.
        exclude_bos: Record lengths without the leading BOS token.
        progress_every: Report progress on stderr every this many shards and
            after the last shard; 0 disables reporting.
        total_shards_matched: Passed through to the result unchanged.
        file_selection: Passed through to the result unchanged.
        drop_final_partial_doc: When true, the document still open after the
            last shard is not recorded; its length is reported as
            ``dropped_tail_tokens`` instead.

    Returns:
        The collected statistics and bookkeeping values.

    Raises:
        ValueError: If tokens appear before the first BOS token of the scan.
    """
    stats = LengthStats(hist=np.zeros((1,), dtype=np.int64))
    total_tokens_seen = 0
    # Tokens so far in the document that is still open; None until the first
    # BOS token of the whole scan has been seen.
    open_len: int | None = None

    def record(doc_lengths: np.ndarray) -> None:
        stats.add(maybe_adjust(doc_lengths, exclude_bos))

    for shard_idx, file in enumerate(files, start=1):
        tokens = load_data_shard(file)
        total_tokens_seen += int(tokens.size)
        starts = np.flatnonzero(tokens == bos_id)

        if starts.size:
            first_bos = int(starts[0])
            if open_len is not None:
                # The first BOS closes the document carried over from earlier shards.
                record(np.asarray([open_len + first_bos], dtype=np.int64))
            elif first_bos != 0:
                raise ValueError(
                    f"{file} begins with {first_bos} tokens before the first BOS token. "
                    "Expected the matched shard set to begin at a document boundary."
                )

            if starts.size > 1:
                # Gaps between consecutive BOS positions are complete documents.
                record(np.diff(starts).astype(np.int64, copy=False))

            # Everything from the last BOS to the end of the shard stays open.
            open_len = int(tokens.size - int(starts[-1]))
        elif open_len is not None:
            # No boundary in this shard: it only extends the open document.
            open_len += int(tokens.size)
        else:
            raise ValueError(
                f"No BOS token found before or inside {file}. "
                "Expected the matched shard set to begin at a document boundary."
            )

        if progress_every and (shard_idx == len(files) or shard_idx % progress_every == 0):
            tail = 0 if open_len is None else open_len
            print(
                f"[progress] shards={shard_idx}/{len(files)} docs={stats.total_docs:,} "
                f"tokens_seen={total_tokens_seen:,} open_doc_tokens={tail:,}",
                file=sys.stderr,
                flush=True,
            )

    dropped_tail_tokens = 0
    if open_len is not None and drop_final_partial_doc:
        dropped_tail_tokens = open_len
    elif open_len is not None:
        record(np.asarray([open_len], dtype=np.int64))

    return ScanResult(
        stats=stats,
        shards_scanned=len(files),
        total_shards_matched=total_shards_matched,
        total_tokens_seen=total_tokens_seen,
        dropped_tail_tokens=dropped_tail_tokens,
        file_selection=file_selection,
    )
|
|
|
|
|
def nearest_rank_quantile(hist: np.ndarray, total_docs: int, q: float) -> int:
    """Return the nearest-rank ``q``-quantile of the lengths counted in ``hist``.

    The result is the smallest length whose cumulative count reaches
    ``ceil(q * total_docs)``; the rank is never taken below 1, so ``q == 0``
    yields the smallest length present.

    Args:
        hist: Histogram where ``hist[n]`` is the number of documents of
            length ``n``.
        total_docs: Number of documents counted in ``hist``.
        q: Quantile as a fraction in ``[0, 1]`` (not a percentage).

    Returns:
        The document length at the requested quantile.

    Raises:
        ValueError: If ``total_docs`` is not positive, if ``q`` is NaN or
            outside ``[0, 1]``, or if ``hist`` holds fewer documents than the
            requested rank.
    """
    if total_docs <= 0:
        raise ValueError("Quantiles require at least one document")
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0.0 <= q <= 1.0:
        raise ValueError(f"Quantile q must be within [0, 1], got {q!r}")

    rank = max(1, int(math.ceil(q * total_docs)))
    cdf = np.cumsum(hist, dtype=np.int64)
    # Without this guard searchsorted would return len(hist), a length that
    # no counted document has.
    if cdf.size == 0 or rank > int(cdf[-1]):
        raise ValueError(
            f"Histogram holds fewer documents than the requested rank {rank}"
        )
    return int(np.searchsorted(cdf, rank, side="left"))
|
|
|
|
|
def print_summary(result: ScanResult, exclude_bos: bool) -> None:
    """Print the scan summary and a power-of-two length histogram to stdout.

    Args:
        result: Scan outcome to report.
        exclude_bos: Only selects the ``length_mode`` label that is printed.

    Raises:
        ValueError: If the scan recorded no documents.
    """
    stats = result.stats
    if stats.total_docs == 0:
        raise ValueError("No complete documents were found")

    quantile_points = (
        ("p50", 0.50),
        ("p90", 0.90),
        ("p95", 0.95),
        ("p99", 0.99),
        ("p99.9", 0.999),
    )
    quantiles = {
        label: nearest_rank_quantile(stats.hist, stats.total_docs, fraction)
        for label, fraction in quantile_points
    }

    if exclude_bos:
        length_mode = "content tokens per document (leading BOS excluded)"
    else:
        length_mode = "stored tokens per document (leading BOS included)"

    print(f"file_selection: {result.file_selection}")
    print(f"length_mode: {length_mode}")
    print(f"shards_scanned: {result.shards_scanned}/{result.total_shards_matched}")
    print(f"documents: {stats.total_docs:,}")
    print(f"tokens_accounted_for: {stats.total_tokens:,}")
    print(f"tokens_seen_in_shards: {result.total_tokens_seen:,}")
    if result.dropped_tail_tokens:
        print(f"dropped_tail_tokens: {result.dropped_tail_tokens:,}")
    print(f"min_len: {stats.min_len}")
    print(f"mean_len: {stats.total_tokens / stats.total_docs:.3f}")
    for name, value in quantiles.items():
        print(f"{name}_len: {value}")
    print(f"max_len: {stats.max_len}")
    print()
    print("log2_histogram:")

    # Length 0 lies outside every power-of-two bucket, so it gets its own row.
    if stats.hist[0]:
        count = int(stats.hist[0])
        frac = 100.0 * count / stats.total_docs
        print(f" [0, 0]: count={count:,} frac={frac:.3f}%")

    # Start at the power of two at or below the shortest document so that
    # leading empty buckets are not visited.
    if stats.min_len is not None and stats.min_len > 1:
        low = 1 << int(math.floor(math.log2(stats.min_len)))
    else:
        low = 1
    while low <= stats.max_len:
        high = min(stats.max_len, 2 * low - 1)
        count = int(stats.hist[low : high + 1].sum(dtype=np.int64))
        if count:
            frac = 100.0 * count / stats.total_docs
            print(f" [{low}, {high}]: count={count:,} frac={frac:.3f}%")
        low <<= 1
|
|
|
|
|
def write_tsv(path: Path, hist: np.ndarray, total_docs: int) -> None:
    """Write the non-zero histogram entries to ``path`` as a TSV table.

    Missing parent directories are created. After the header line there is
    one row per length with a non-zero count, in increasing length order,
    holding the length, its count, its share of ``total_docs`` and the
    cumulative share.

    Args:
        path: Output file; overwritten if it exists.
        hist: Histogram where ``hist[n]`` is the count for length ``n``.
        total_docs: Denominator for both fraction columns.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    present = np.flatnonzero(hist)
    with path.open("w", encoding="utf-8") as handle:
        emit = handle.write
        emit("doc_len_tokens\tcount\tfraction\tcumulative_fraction\n")
        cumulative = 0
        for length, raw_count in zip(present.tolist(), hist[present].tolist()):
            count = int(raw_count)
            cumulative += count
            emit(
                f"{int(length)}\t{count}\t{count / total_docs:.12f}\t{cumulative / total_docs:.12f}\n"
            )
|
) |
|
|
|
|
|
def expand_file_specs(specs: list[str]) -> list[Path]:
    """Expand paths and glob patterns into a de-duplicated list of files.

    Specs are expanded in the order given and the matches of each spec are
    sorted. A path that several specs match is kept only at its first
    position; paths are compared after ``os.path.normpath``. A spec that
    matches nothing contributes nothing.

    Args:
        specs: File paths and/or glob patterns.

    Returns:
        The matched paths, normalized, in first-seen order.
    """
    # A dict keeps insertion order, so it serves as an ordered set here.
    first_seen: dict[str, None] = {}
    for spec in specs:
        for hit in sorted(glob.glob(spec)):
            first_seen.setdefault(os.path.normpath(hit), None)
    return [Path(norm) for norm in first_seen]
|
|
|
|
|
def resolve_file_selection(args: argparse.Namespace) -> tuple[list[Path], str]:
    """Turn the parsed options into a shard list and a description of it.

    Precedence is ``--files``, then ``--pattern``, then ``--split`` combined
    with ``--data-path``.

    Args:
        args: Parsed options; reads ``files``, ``pattern``, ``split`` and
            ``data_path``.

    Returns:
        The matched files and a human-readable description of the selection.
    """
    if args.files:
        return expand_file_specs(args.files), "files=" + ", ".join(args.files)

    if args.pattern:
        return expand_file_specs([args.pattern]), f"pattern={args.pattern}"

    # "test" is an alias that reads the same shards as "val".
    glob_by_split = {
        "train": "fineweb_train_*.bin",
        "val": "fineweb_val_*.bin",
        "test": "fineweb_val_*.bin",
    }
    pattern = str(Path(args.data_path) / glob_by_split[args.split])
    return expand_file_specs([pattern]), f"split={args.split} pattern={pattern}"
|
|
|
|
|
def main() -> None:
    """Run the scan for the command-line selection and print the report.

    Raises:
        FileNotFoundError: If the selection matches no files.
    """
    args = parse_args()
    all_files, file_selection = resolve_file_selection(args)
    if not all_files:
        raise FileNotFoundError(f"No files found for selection: {file_selection}")

    total_shards_matched = len(all_files)
    limit = args.limit_shards
    files = all_files[:limit] if limit > 0 else all_files
    # When the limit cuts shards off, the document left open at the end is
    # incomplete, so it is dropped rather than recorded as a short document.
    truncated = limit > 0 and len(files) < total_shards_matched

    result = scan_doc_lengths(
        files=files,
        bos_id=args.bos_id,
        exclude_bos=args.exclude_bos,
        progress_every=args.progress_every,
        total_shards_matched=total_shards_matched,
        file_selection=file_selection,
        drop_final_partial_doc=truncated,
    )
    print_summary(result, exclude_bos=args.exclude_bos)

    if args.write_tsv is None:
        return
    write_tsv(args.write_tsv, result.stats.hist, result.stats.total_docs)
    print()
    print(f"wrote_tsv: {args.write_tsv}")
|
|
|
|
|
# Run the scan only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()