@aug2uag
Created April 26, 2025 00:40
JAX model runner
# model runner
"""
JAX-based Model Runner for ARC Challenges
This module provides JAX implementations of the components needed for test-time training
and inference on ARC challenges.
"""
import json
import os, sys
import bz2
import pickle
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
import flax
from flax import linen as nn
from flax.training import train_state
import optax
from typing import Dict, List, Tuple, Optional, Any, Callable, Union
import logging
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers.generation import FlaxGenerationMixin
from dataclasses import dataclass
# Set up logging
logger = logging.getLogger(__name__)
def indices_required_for_merges(keep_indices, vocab, merges):
"""Determine which token indices are required for BPE merges."""
merges_lookup = {}
for m in merges:
a, b = m.split(' ') if isinstance(m, str) else m
key = vocab[f'{a}{b}']
if key not in merges_lookup: merges_lookup[key] = set()
merges_lookup[key].add(vocab[a])
merges_lookup[key].add(vocab[b])
to_process = list(keep_indices)
while len(to_process):
for w in merges_lookup.get(to_process.pop(), []):
if w not in keep_indices:
keep_indices[w] = None
to_process.append(w)
return keep_indices
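# Minimal illustration of the merge-closure logic above on a toy BPE vocab and merge
# list (purely hypothetical values, not taken from a real tokenizer): keeping the
# merged token 'ab' also forces 'a' and 'b' to be kept, since the merge 'a b' -> 'ab'
# cannot be applied without them.
def _example_merge_closure():
    toy_vocab = {'a': 0, 'b': 1, 'ab': 2, 'c': 3}
    toy_merges = ['a b']
    keep = indices_required_for_merges({2: None}, toy_vocab, toy_merges)
    return sorted(keep)  # -> [0, 1, 2]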
def remove_unused_merges(merges, vocab):
"""Remove BPE merges that use tokens not in the vocabulary."""
return [f'{a} {b}' for a, b in [m.split(' ') if isinstance(m, str) else m for m in merges] if all(w in vocab for w in [a, b, a + b])]
def map_special_tokens(data, mapping=None):
"""Map special tokens using the provided mapping."""
tokens = set()
if isinstance(data, dict):
special = data.get('special_tokens')
if special is not None:
for v in special.values():
tokens.update(v['ids'])
if mapping is not None:
v['ids'] = [mapping.get(i) for i in v['ids'] if i in mapping]
for v in (data.values() if isinstance(data, dict) else data if isinstance(data, list) else []):
tokens.update(map_special_tokens(v, mapping))
return tokens
def remove_tokenizer_normalizer(tokenizer):
"""Remove normalizer from tokenizer to preserve exact input."""
tokenizer_json = json.loads(tokenizer.backend_tokenizer.to_str())
if tokenizer_json.get('normalizer') is not None:
tokenizer_json['normalizer'] = None
        # `backend_tokenizer` is a read-only property, so rebuild the underlying fast
        # tokenizer from the edited JSON and assign it directly
        from tokenizers import Tokenizer
        tokenizer._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
return tokenizer
def shrink_tokenizer_vocab(tokenizer, keep_indices, keep_special_tokens, keep_token_order):
"""Shrink tokenizer vocabulary to only keep necessary tokens."""
tokenizer_json = json.loads(tokenizer.backend_tokenizer.to_str())
assert tokenizer_json['model']['type'] == "BPE"
if keep_special_tokens:
keep_indices.update({k: None for k in tokenizer.all_special_ids})
keep_indices.update({k: None for k in map_special_tokens(tokenizer_json.get('post_processor'))})
keep_indices = indices_required_for_merges(keep_indices, tokenizer_json['model']['vocab'], tokenizer_json['model']['merges'])
if keep_token_order: keep_indices = sorted(keep_indices)
mapping = {old: new for new, old in enumerate(keep_indices)}
tokenizer_json['model']['vocab'] = {k: mapping[v] for k, v in tokenizer_json['model']['vocab'].items() if v in mapping}
tokenizer_json['model']['merges'] = remove_unused_merges(tokenizer_json['model']['merges'], tokenizer_json['model']['vocab'])
special_tokens_order = [t['id'] for t in tokenizer_json['added_tokens']]
assert special_tokens_order == sorted(special_tokens_order)
tokenizer_json['added_tokens'] = sorted([{**t, 'id': mapping[t['id']]} for t in tokenizer_json['added_tokens'] if t['id'] in mapping], key=lambda t: t['id'])
map_special_tokens(tokenizer_json.get('post_processor'), mapping)
    # Rebuild the backend tokenizer from the updated JSON (AutoTokenizer cannot be
    # instantiated from an in-memory JSON string, so swap the fast tokenizer in place)
    from tokenizers import Tokenizer
    tokenizer._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
    return tokenizer, mapping, keep_indices
def shrink_model_embeddings(model, keep_indices, mapping):
"""Adjust model embeddings to fit the shrunk vocabulary."""
config = model.config
# Create a new version of embeddings with only the kept indices
with jax.default_device(jax.devices("cpu")[0]):
params = model.params
# Extract original embedding weights
embed_weights = params['transformer']['wte']['embedding']
lm_head_weights = params['lm_head']['kernel'] if 'lm_head' in params else None
# Create new embedding weights with only the kept indices
row_select = jnp.array(list(keep_indices), dtype=jnp.int32)
new_embed_weights = embed_weights[row_select]
# Update model parameters
params['transformer']['wte']['embedding'] = new_embed_weights
if lm_head_weights is not None:
new_lm_head_weights = lm_head_weights[row_select]
params['lm_head']['kernel'] = new_lm_head_weights
# Update configurations
for config_obj in [config]:
for k, v in list(config_obj.to_dict().items()):
if k.endswith('token_id'):
setattr(config_obj, k, [mapping.get(t) for t in v] if isinstance(v, list) else mapping.get(v))
# Create new model with updated parameters
new_model = type(model)(config=config)
new_model.params = params
return new_model
def shrink_embeddings(model, tokenizer, corpus=None, keep_token_ids=[], keep_tokens=[], remove_token_ids=[], keep_model_tokens=True, keep_special_tokens=True, keep_normalizer=False, keep_token_order=True):
"""Shrink model embeddings and tokenizer to only keep necessary tokens."""
if not keep_normalizer:
tokenizer = remove_tokenizer_normalizer(tokenizer)
from collections import OrderedDict # use as OrderedSet
keep_indices = OrderedDict()
keep_indices.update({k: None for k in keep_token_ids})
keep_indices.update({tokenizer.vocab[t]: None for t in keep_tokens})
if corpus is not None:
keep_indices.update({k: None for k in tokenizer(corpus)['input_ids']})
if keep_model_tokens:
for config in [model.config]:
for k, v in config.to_dict().items():
if k.endswith('token_id'):
keep_indices.update({k: None for k in (v if isinstance(v, list) else [v])})
keep_indices.pop(None, None)
for idx in remove_token_ids:
keep_indices.pop(idx, None)
new_tokenizer, mapping, keep_indices = shrink_tokenizer_vocab(tokenizer, keep_indices, keep_special_tokens, keep_token_order)
new_model = shrink_model_embeddings(model, keep_indices, mapping=mapping)
return new_model, new_tokenizer, mapping
def fix_dtypes(model, fix_weights=True, fix_quant_states=True):
"""Ensure all model parameters have consistent dtypes."""
# In JAX, we don't need to iterate through modules like in PyTorch
# Instead, we create a tree map function to fix dtypes in the parameter tree
params = model.params
default_dtype = jnp.float32 # or model's default dtype
def fix_param_dtype(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
if fix_weights and param.dtype != default_dtype:
return param.astype(default_dtype)
return param
# Apply the function to all parameters
    new_params = jax.tree_util.tree_map(fix_param_dtype, params)
model.params = new_params
return model
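# Tiny sketch of the dtype normalization performed by fix_dtypes, on a toy parameter
# tree (shapes and names are illustrative only): floating-point leaves are cast to
# float32, everything else is left untouched.
def _example_fix_dtypes_tree():
    params = {'dense': {'kernel': jnp.ones((2, 2), dtype=jnp.bfloat16),
                        'bias': jnp.zeros((2,), dtype=jnp.float32)}}
    return jax.tree_util.tree_map(
        lambda p: p.astype(jnp.float32) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params)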
def merge_peft_into_base(model):
"""Merge LoRA parameters into base model."""
logger.info('*** Merge peft model into base model...')
# Get base params and LoRA params
base_params = model.params
lora_params = model.lora_params
# Create merged parameters
merged_params = merge_lora_weights(base_params, lora_params, model.config.lora_config)
# Create new model with merged parameters
new_model = type(model)(config=model.config)
new_model.params = merged_params
return fix_dtypes(new_model)
def merge_lora_weights(base_params, lora_params, lora_config):
    """Merge LoRA weights into base weights: W' = W + (B @ A) * alpha / r."""
    scaling = lora_config.lora_alpha / lora_config.r

    def merge_leaf(path, param):
        # Match each leaf by its path string against the stored LoRA factors
        path_str = '/'.join(str(p) for p in path)
        lora_pair = lora_params.get(path_str)
        if lora_pair is None:
            return param
        lora_a, lora_b = lora_pair['lora_A'], lora_pair['lora_B']
        # Compute LoRA update: (B @ A) * alpha / r
        return param + (lora_b @ lora_a) * scaling

    return jax.tree_util.tree_map_with_path(merge_leaf, base_params)
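# Quick sanity sketch of the merge formula W' = W + (B @ A) * alpha / r on random toy
# matrices (shapes and hyperparameters chosen arbitrarily for illustration).
def _example_lora_merge_formula():
    key_w, key_a, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
    w = jax.random.normal(key_w, (8, 8))        # base weight
    lora_a = jax.random.normal(key_a, (4, 8))   # A: (r, in_features)
    lora_b = jax.random.normal(key_b, (8, 4))   # B: (out_features, r)
    alpha, r = 16.0, 4
    return w + (lora_b @ lora_a) * (alpha / r)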
def save_model(store_path, model=None, tokenizer=None, merge=False):
"""Save model and tokenizer to disk."""
if merge and model is not None:
model = merge_peft_into_base(model)
if store_path is not None:
assert model is not None or tokenizer is not None
logger.info(f"*** Saving{' merged' if merge else ''} model/tokenizer to '{store_path}'...")
os.makedirs(store_path, exist_ok=True)
if model is not None:
# Save model config and parameters
with open(os.path.join(store_path, 'config.json'), 'w') as f:
json.dump(model.config.to_dict(), f)
# Save model parameters
with open(os.path.join(store_path, 'model.safetensors'), 'wb') as f:
params_bytes = flax.serialization.msgpack_serialize(model.params)
f.write(params_bytes)
if tokenizer is not None:
tokenizer.save_pretrained(store_path)
# Remove unnecessary tokenizer model file if it exists
to_delete = os.path.join(store_path, 'tokenizer.model')
if os.path.isfile(to_delete):
os.remove(to_delete)
return model
def is_unsloth_model(model):
"""Check if model is an unsloth model."""
# JAX doesn't have unsloth models, but we'll keep the function for compatibility
return False
def is_peft_model(model):
"""Check if model is a PEFT model."""
return hasattr(model, 'lora_params')
def download_model(repo_id, store_path, get_name=lambda n: os.path.join(n.replace('/', '--'), 'transformers', 'default', '1')):
"""Download a model from Hugging Face or use local cache."""
if os.path.exists(repo_id):
return repo_id
model_path = os.path.join(store_path, get_name(repo_id))
if not os.path.exists(model_path):
from huggingface_hub import snapshot_download
download_path = snapshot_download(repo_id=repo_id)
os.makedirs(os.path.split(model_path)[0], exist_ok=True)
os.symlink(download_path, model_path, target_is_directory=True)
return model_path
def get_and_fix_peft_weights(store):
"""Load and fix PEFT state weights."""
logger.info(f"*** Load peft state_dict from '{store}'...")
# Load PEFT weights
with open(os.path.join(store, 'lora_params.msgpack'), 'rb') as f:
state_dict = flax.serialization.msgpack_deserialize(f.read())
# Filter out unwanted keys
filtered_dict = {}
for k, v in state_dict.items():
if 'modules_to_save' not in k:
filtered_dict[k] = v
return filtered_dict
def set_peft_weights(model, state_dict):
"""Set PEFT weights in the model."""
logger.info(f"*** Set model state_dict...")
# Set lora parameters
model.lora_params = state_dict
return model
class LoraConfig:
"""Configuration for LoRA."""
def __init__(self, r=8, target_modules=None, lora_alpha=16, lora_dropout=0.0,
bias="none", use_gradient_checkpointing=False, random_state=42,
use_rslora=False, loftq_config=None):
self.r = r
self.target_modules = target_modules or []
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
self.bias = bias
self.use_gradient_checkpointing = use_gradient_checkpointing
self.random_state = random_state
self.use_rslora = use_rslora
self.loftq_config = loftq_config
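# Illustrative LoRA configuration (hypothetical hyperparameters; which target_modules
# names actually exist depends on the base model architecture).
def _example_lora_config():
    return LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.0,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    )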
def apply_lora_to_model(model, lora_config):
    """Apply LoRA to model parameters by attaching low-rank factors per target module."""
    # Set up random key
    rng = jax.random.PRNGKey(lora_config.random_state)
    # Initialize LoRA parameters, keyed by each parameter's path string
    lora_params = {}
    for path, param in jax.tree_util.tree_leaves_with_path(model.params):
        path_str = '/'.join(str(p) for p in path)
        if not any(m in path_str for m in lora_config.target_modules):
            continue
        if isinstance(param, jnp.ndarray) and param.ndim == 2:  # weight matrix
            # Initialize A and B matrices
            rng, key_a, key_b = jax.random.split(rng, 3)
            if lora_config.use_rslora:
                # RS-LoRA style initialization
                lora_a = jax.random.normal(key_a, (lora_config.r, param.shape[1])) * 0.1
                lora_b = jax.random.normal(key_b, (param.shape[0], lora_config.r)) * 0.1
            else:
                # Standard LoRA initialization: A ~ small noise, B = 0
                lora_a = jax.random.normal(key_a, (lora_config.r, param.shape[1])) * 0.01
                lora_b = jnp.zeros((param.shape[0], lora_config.r))
            lora_params[path_str] = {'lora_A': lora_a, 'lora_B': lora_b}
    # Attach LoRA parameters and config to the model
    model.lora_params = lora_params
    model.config.lora_config = lora_config
    return model
def prepare_model(model, mode, tokenizer=None, formatter=None, shrink_embedding=False,
dequantize=False, peft=[], local_files_only=False, add_special_tokens={},
set_pad_token=None, keep_tokens=[], keep_normalizer=None,
peft_trainable=True, device_map=None, tf_grad_cp=True, tf_use_fa2=True, **kwargs):
"""Prepare model and tokenizer with various optimizations."""
if isinstance(model, str):
assert tokenizer is None
logger.info(f"*** Load base model and tokenizer from '{model}'...")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model, local_files_only=local_files_only, **kwargs)
# Load model based on mode
if mode != 'tokenizer_only':
from transformers import FlaxAutoModelForCausalLM
# JAX doesn't support device mapping like PyTorch, ignore device_map
model_load_args = {}
# Load model
model = FlaxAutoModelForCausalLM.from_pretrained(model, **model_load_args)
else:
model = None
# Add special tokens if needed
if add_special_tokens:
tokenizer.add_special_tokens(add_special_tokens)
# Set pad token if needed
if set_pad_token is not None:
tokenizer.pad_token = set_pad_token
# Set formatter
if formatter is not None and not hasattr(formatter, 'corpus'):
formatter = formatter(tokenizer=tokenizer)
# Apply embedding shrinking if needed
if (shrink_embedding < len(tokenizer.vocab) if type(shrink_embedding) == int else shrink_embedding) or keep_normalizer is False:
logger.info('*** Shrink embedding...')
embedding_size_before_shrink = len(tokenizer.vocab)
model, tokenizer, mapping = shrink_embeddings(
model, tokenizer, formatter.get_corpus(),
keep_tokens=keep_tokens, keep_normalizer=keep_normalizer
)
logger.info(f'*** -> Reduced embedding size from {embedding_size_before_shrink} to {len(mapping)} words.')
# Apply LoRA or other parameter-efficient methods
if len(peft):
peft_trained = True if is_peft_model(model) else None
for i, m in enumerate(peft):
if peft_trained is True:
model, peft_trained = merge_peft_into_base(model), None
if isinstance(m, str):
                if peft_trained is False:
                    model, peft_trained = set_peft_weights(model, get_and_fix_peft_weights(m)), True
else:
logger.info(f"*** Load peft model from '{m}'...")
state_dict = get_and_fix_peft_weights(m)
model = set_peft_weights(model, state_dict)
peft_trained = True
else:
assert peft_trained is None
if isinstance(m, dict):
logger.info('*** Create new peft model...')
lora_config = LoraConfig(**m)
model = apply_lora_to_model(model, lora_config)
peft_trained = False
else:
assert m is None
return model, tokenizer, formatter
@dataclass
class TrainingConfig:
"""Configuration for training."""
batch_size: int = 1
learning_rate: float = 1e-4
weight_decay: float = 0.01
num_epochs: int = 1
warmup_steps: int = 100
max_steps: int = -1
seed: int = 42
logging_steps: int = 10
save_strategy: str = "no"
output_dir: str = "./output"
gradient_accumulation_steps: int = 1
embedding_learning_rate: float = 1e-5
lr_scheduler_type: str = "linear"
optim: str = "adamw"
def create_optimizer(config):
"""Create optimizer based on configuration."""
if config.lr_scheduler_type == "linear":
schedule_fn = optax.linear_schedule(
init_value=config.learning_rate,
end_value=0,
transition_steps=config.max_steps,
transition_begin=config.warmup_steps
)
elif config.lr_scheduler_type == "cosine":
schedule_fn = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=config.learning_rate,
warmup_steps=config.warmup_steps,
decay_steps=config.max_steps,
end_value=0.0
)
else: # constant
schedule_fn = config.learning_rate
# Create optimizer with weight decay
optimizer = optax.adamw(
learning_rate=schedule_fn,
weight_decay=config.weight_decay,
b1=0.9,
b2=0.999,
eps=1e-8
)
return optimizer
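# Sketch of how the optimizer factory above is typically wired up (values are
# illustrative): a cosine schedule with 100 warmup steps out of 1000 total steps.
def _example_optimizer():
    cfg = TrainingConfig(learning_rate=1e-4, warmup_steps=100, max_steps=1000,
                         lr_scheduler_type='cosine')
    return create_optimizer(cfg)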
def training_run(model, formatter, dataset, train_args, max_seq_length, merge=False,
store=None, packing=False, grad_acc_fix=False, optimizers=None):
"""Run training on the model."""
assert merge is False, "Merge after training not supported in JAX implementation"
# Create training config
config = TrainingConfig(**train_args)
# Prepare dataset
formatter.tokenizer.padding_side = 'right'
train_data = dataset.as_list(formatter)
# Create optimizer
if optimizers is None:
optimizer = create_optimizer(config)
else:
optimizer = optimizers[0]
# Create training state
rng = jax.random.PRNGKey(config.seed)
# Define training step
def train_step(state, batch, rng):
"""Single training step."""
def loss_fn(params):
            # HF Flax models are called directly, passing `params` and a dropout RNG
            logits = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                params=params,
                dropout_rng=rng,
                train=True,
            ).logits
# Shift logits and labels for next token prediction
shift_logits = logits[:, :-1]
shift_labels = batch["labels"][:, 1:]
# Calculate loss (with label smoothing if specified)
loss = optax.softmax_cross_entropy(
shift_logits,
jax.nn.one_hot(shift_labels, shift_logits.shape[-1])
)
# Mask out padding tokens
mask = (shift_labels != -100)
loss = (loss * mask).sum() / mask.sum()
return loss
# Get gradients
loss, grads = jax.value_and_grad(loss_fn)(state.params)
# Update parameters
new_state = state.apply_gradients(grads=grads)
return new_state, loss
# Create initial training state
state = train_state.TrainState.create(
        apply_fn=model.__call__,
params=model.params,
tx=optimizer
)
# JIT compile the training step
jit_train_step = jax.jit(train_step)
# Training loop
rng, train_rng = jax.random.split(rng)
logger.info("*** Start training run...")
# Calculate number of training steps
num_examples = len(train_data)
steps_per_epoch = num_examples // config.batch_size
total_steps = steps_per_epoch * config.num_epochs
if config.max_steps > 0:
total_steps = min(total_steps, config.max_steps)
# Training loop
step = 0
epoch = 0
with tqdm(total=total_steps, desc="Training") as progress_bar:
while step < total_steps:
# Shuffle data for new epoch
if step % steps_per_epoch == 0:
rng, data_rng = jax.random.split(rng)
shuffled_idx = jax.random.permutation(data_rng, jnp.arange(num_examples))
train_data = [train_data[i] for i in shuffled_idx]
epoch += 1
# Get batch
batch_start = (step % steps_per_epoch) * config.batch_size
batch_end = min(batch_start + config.batch_size, num_examples)
batch_data = train_data[batch_start:batch_end]
# Prepare batch
input_texts = [item["text"] for item in batch_data]
encodings = formatter.tokenizer(
input_texts,
max_length=max_seq_length,
padding="max_length",
truncation=True,
return_tensors="jax"
)
# Prepare labels
input_ids = encodings["input_ids"]
labels = jnp.where(
input_ids == formatter.tokenizer.pad_token_id,
-100, # Ignore padding tokens in loss
input_ids
)
# Modify labels based on formatter's data collator if needed
collator = formatter.get_data_collator()
if collator is not None:
# Implement collator logic specific to your formatter
pass
# Prepare batch for JAX
batch = {
"input_ids": input_ids,
"attention_mask": encodings["attention_mask"],
"labels": labels
}
# Train step
rng, step_rng = jax.random.split(rng)
state, loss = jit_train_step(state, batch, step_rng)
if step % config.logging_steps == 0:
logger.info(f"Step {step}: loss = {loss.item():.4f}")
step += 1
progress_bar.update(1)
progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
# Update model parameters
model.params = state.params
# Save model if needed
if store is not None:
save_model(store, model, formatter.tokenizer, merge=merge)
logger.info("*** Training completed.")
return model, {"metrics": {"train_runtime": None}} # For compatibility
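# Illustrative train_args dict for training_run (hypothetical hyperparameters); it is
# unpacked into TrainingConfig, so only fields defined on that dataclass are valid keys.
def _example_train_args():
    return dict(batch_size=2, learning_rate=1e-4, num_epochs=1, warmup_steps=50,
                max_steps=200, gradient_accumulation_steps=1, logging_steps=10)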
def inference_load(store, keys=True, result_dict=None, always_read_from_file=False):
"""Load inference results from store."""
if result_dict is None:
result_dict = {}
if store is not None:
if keys is True:
keys = os.listdir(store)
for key in keys:
if always_read_from_file or key not in result_dict:
try:
with bz2.BZ2File(os.path.join(store, key)) as f:
result_dict[key] = pickle.load(f)
except:
continue
return result_dict
def inference_save(store, key, outputs):
"""Save inference results to store."""
if store is not None:
os.makedirs(store, exist_ok=True)
with bz2.BZ2File(os.path.join(store, key), 'w') as f:
pickle.dump(outputs, f)
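# Round-trip sketch for the bz2+pickle store format used above (the store path, key
# and payload shape are illustrative placeholders, not the real de-tokenized format).
def _example_inference_store_roundtrip(tmp_store='/tmp/arc_inference_store'):
    inference_save(tmp_store, 'task123.out0', [(3, -0.5, [[1, 2], [3, 4]])])
    return inference_load(tmp_store, keys=['task123.out0'])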
class Decoder:
"""Decoder for model outputs."""
def __init__(self, formatter, dataset, n_guesses, max_outputs=None, frac_score=False,
quiet=False, name='', additional_decoders=None, prob_baseline=None):
self.formatter = formatter
self.dataset = dataset
self.n_guesses = n_guesses
self.decoded_results = {}
self.correct_solutions = {}
self.keys_lim = set()
self.keys_all = set()
self.mult_cnt = {}
self.keys_cnt = {}
self.frac_score = frac_score
self.max_outputs = max_outputs
self.quiet = quiet
        self.input_len = [{} if formatter is None or formatter.tokenizer is None else ds.get_lengths(formatter, name='input') for ds in [dataset, dataset.mod(np.transpose, keep_key=True)]]
        self.reply_len = [{} if formatter is None or formatter.tokenizer is None else ds.get_lengths(formatter, name='reply') for ds in [dataset, dataset.mod(np.transpose, keep_key=True)]]
self.additional_decoders = additional_decoders
self.name = name
self.prob_tracker = {}
self.prob_tracker_best = {}
self.prob_baseline = prob_baseline
def score(self, *to_score):
"""Score the decoded results."""
scores = [(sum(1/self.mult_cnt[k.split('_')[0]] for k in s) if self.frac_score else len(s)) for s in to_score]
score_cnt = len(self.mult_cnt if self.frac_score else self.keys_cnt)
return scores, score_cnt
def from_store(self, store, **kwargs):
"""Load results from store."""
for key, outputs in inference_load(store).items():
self.process(key, outputs, **kwargs)
return self
def score_fmt(self, v):
"""Format score for display."""
return f'{v:5.1f}' if self.frac_score else f'{v:3}'
def process_single_output(self, key, output_len, decoded, print_func=print, len_info=None, device_info=None):
"""Process a single output from the model."""
inv_mod = {k: v if k.endswith('val') else self.dataset.invert_mod(v, key, inv_perm=(k.startswith('output') or k.startswith('score_all'))) for k, v in decoded.items()}
base_key = key.split('.')[0]
self.decoded_results[base_key] = self.decoded_results.get(base_key, {})
self.decoded_results[base_key][key] = inv_mod
output = inv_mod.get('output')
score = inv_mod.get('score')
# quick scoring
self.keys_cnt[base_key] = self.keys_cnt.get(base_key, 0) + 1
mult_key, mult_sub = (base_key.split('_') + ['0'])[:2]
self.mult_cnt[mult_key] = max(self.mult_cnt.get(mult_key, 0), int(mult_sub) + 1)
if len(self.dataset.replies):
correct_solution = self.dataset.replies.get(base_key)
if correct_solution is not None:
correct_solution = correct_solution[0]
self.correct_solutions[base_key] = correct_solution
is_correct = correct_solution is not None and np.array_equal(correct_solution, output)
if is_correct:
self.keys_all.add(base_key)
if self.keys_cnt[base_key] <= self.n_guesses: self.keys_lim.add(base_key)
corr_str = 'cant_decode' if output is None else 'sol_unknown' if correct_solution is None else 'ALL_CORRECT' if is_correct else 'bad_xy_size' if np.shape(correct_solution)!=np.shape(output) else 'bad_content'
(score_lim, score_all), score_cnt = self.score(self.keys_lim, self.keys_all)
tp_arr = (key.count('transpose') + key.count('rot90')) % 2
msc = None if score is None else np.sum(score)
fsc = inv_mod.get('score_val')
if output is not None and fsc is not None:
pt = self.prob_tracker[base_key] = self.prob_tracker.get(base_key, {})
hash = tuple(map(tuple, output))
prob = pt[hash] = pt.get(hash, 0) + (np.exp(fsc) if self.prob_baseline is None else fsc - np.log(self.prob_baseline))
current_best = self.prob_tracker_best.get(base_key)
if current_best is None or current_best[0]<prob:
self.prob_tracker_best[base_key] = (prob, output)
fmt_name = f'{self.name}: ' if self.name else ''
msc_print = f'{min(-msc, 9.99999):7.5f}' if msc is not None else 'unknown'
fsc_print = f'{min(-fsc, 9.99999):7.5f}' if fsc is not None else 'unknown'
if not self.quiet: print_func(f" {fmt_name}acc: {self.score_fmt(score_lim)}/{score_cnt:3}={min(score_lim/score_cnt, 0.999):5.1%} (2-guess), {self.score_fmt(score_all)}/{score_cnt:3}={min(score_all/score_cnt, 0.999):5.1%} (any);{f' {device_info}' if device_info else ''} tok:{self.input_len[tp_arr].get(base_key, '?'):>4}+{self.reply_len[tp_arr].get(base_key, '?'):>3}>{'n/a' if output_len is None else output_len:>3} {corr_str}:{msc_print}|{fsc_print} [{key}]")
def get_current_best(self, base_key):
"""Get the current best output for a key."""
current_best = self.prob_tracker_best.get(base_key)
return None if current_best is None else current_best[1]
def process_single_decode(self, key, de_tokenized, print_func=print, **kwargs):
"""Process single decoded output."""
if len(de_tokenized)==3 and not isinstance(de_tokenized[1], float): # for backwards compatibility
output_len, *data = de_tokenized
score_val = None
else: output_len, score_val, *data = de_tokenized
if self.formatter is None:
assert len(data) == 1
decoded = [data[0]]
else: decoded = self.formatter.decode_to_array(*data)
for d in decoded: d['score_val'] = score_val
for i, dec in enumerate(decoded):
if i==0: self.process_single_output(key, output_len, dec, print_func=print_func, **kwargs)
elif self.additional_decoders:
if i-1<len(self.additional_decoders): self.additional_decoders[i-1].process_single_output(key, output_len, dec, print_func=print_func, **kwargs)
else: print_func(f'{key} no decoder available for output #{i}')
else: self.process_single_output(f'{key}.fix{i}', output_len, dec, print_func=print_func, **kwargs)
def process(self, key, de_tokenized, **kwargs):
"""Process all decoded outputs for a key."""
for i, d in enumerate(de_tokenized):
if self.max_outputs is None or i<=self.max_outputs:
self.process_single_decode(f'{key}.out{i}', d, **kwargs)
def get_unsolved_keys(self):
"""Get keys that haven't been solved yet."""
unsolved = []
for base_key, reply in self.dataset.replies.items():
if not any(np.array_equal(reply[0], s.get('output')) for s in self.decoded_results.get(base_key, {}).values()):
unsolved.append(base_key)
return unsolved
def run_selection_algo(self, selection_algorithm):
"""Run a selection algorithm on the decoded results."""
return {bk: (selection_algorithm({k: g for k, g in v.items() if g.get('output') is not None}) if any(g.get('output') is not None for g in v.values()) else []) for bk, v in self.decoded_results.items()}
def benchmark_selection_algos(self, selection_algorithms, skip_failed=True):
"""Benchmark various selection algorithms."""
results = {}
logger.info('*** Benchmark selection algorithms...')
for selection_algorithm in selection_algorithms:
name = selection_algorithm.__name__
try:
selected = self.run_selection_algo(selection_algorithm)
if self.formatter is not None:
for sols in selected.values():
for s in sols:
assert self.formatter.is_valid_solution(s), f'found invalid solutions {s}'
correct_keys = {k for k, v in selected.items() if self.correct_solutions.get(k) is not None and any(np.array_equal(guess, self.correct_solutions[k]) for guess in v[:self.n_guesses])}
(score,), score_cnt = self.score(correct_keys)
results[name] = score
logger.info(f" acc: {score:5.1f}/{score_cnt:3}={score/score_cnt:6.2%} ('{name}')")
except:
logger.info(f" {'execution failed':>21} ('{name}')")
if not skip_failed: raise
return results
def calc_augmented_scores(self, model, base_keys=None, store=None, seed=0, max_len=None, make_unique=True, quiet=False, **kwargs):
"""Calculate augmented scores for outputs."""
if base_keys is None: base_keys = list(self.decoded_results.keys())
if store is not None: store = f'{store}_new' # new format is not backwards compatible, so use new folder
iterator = base_keys if quiet else tqdm(base_keys, desc='calculate augmented scores', file=sys.stdout)
for bk in iterator:
res = self.decoded_results.get(bk, {})
known_scores = {}
for k, v in sorted(res.items()):
if 'output' in v:
k_store = None if store is None else os.path.join(store, k)
id = tuple(map(tuple, v['output']))
if not (make_unique and id in known_scores):
try:
assert k_store is not None
with bz2.BZ2File(k_store) as f: known_scores[id] = pickle.load(f)
if isinstance(known_scores[id], list): known_scores[id] = dict(score_multi=known_scores[id]) # for backwards compatibility
k_store = None
except:
temp_dataset = self.dataset.__class__(
keys=[bk],
queries={bk: self.dataset.queries.get(bk)},
replies={bk: [v['output'].tolist()]},
)
temp_decoder = self.__class__(self.formatter, temp_dataset, n_guesses=self.n_guesses, quiet=True)
temp_dataset = temp_dataset.augment(**kwargs, seed=(seed+hash(k)+hash(id)) % 1024**2, quiet=True)
if max_len is not None: temp_dataset = temp_dataset.cut_to_len(formatter=self.formatter, name='input', max_len=max_len, quiet=True)
for x in temp_dataset.as_list(self.formatter): calc_score(**x, formatter=self.formatter, model=model, decoder=temp_decoder)
known_scores[id] = dict(
score_multi=[np.sum(x['score']) for x in temp_decoder.decoded_results[bk].values()],
score_multi_nl=[x['score_val'] for x in temp_decoder.decoded_results[bk].values()],
score_multi_array=np.array([x['score'] for x in temp_decoder.decoded_results[bk].values()]),
score_multi_array_cum=np.array([x['score_cum'] for x in temp_decoder.decoded_results[bk].values()]),
score_multi_array_all=np.array([x['score_all'] for x in temp_decoder.decoded_results[bk].values()]),
score_multi_array_all_cum=np.array([x['score_all_cum'] for x in temp_decoder.decoded_results[bk].values()]),
)
if k_store is not None:
os.makedirs(store, exist_ok=True)
with bz2.BZ2File(k_store, 'w') as f: pickle.dump(known_scores[id], f)
v.update(known_scores[id])
def turbo_dfs(model, logits, path, eos_token_id, max_new_tokens, max_score, max_score_greedy, temperature, suppress_tokens, score=0.0, pos=0, cache=None):
"""Depth-first search for generating sequences in a compute-efficient manner."""
    # `logits` is the next-token distribution (a 1D vector over the vocabulary)
    if len(suppress_tokens):
        logits = logits.at[jnp.array(suppress_tokens)].set(-1e10)  # Suppress tokens
    # Calculate the negative log likelihood of each candidate token
    if temperature > 0:
        logits = logits / temperature
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -np.array(log_probs)
    # Find the greedy (lowest-NLL) token index
    greedy_index = int(np.argmin(nll))
    # Sort candidate tokens by NLL so cheaper continuations are explored first
    nll_list = sorted(((i, float(nll[i])) for i in range(nll.shape[-1])), key=lambda x: x[1])
# Follow precomputed path first if available
if path and len(path) > 0:
first_token = path[0]
# Swap first token with the path token
for i, (idx, _) in enumerate(nll_list):
if idx == first_token:
nll_list[0], nll_list[i] = nll_list[i], nll_list[0]
path = path[1:]
break
suffixes = []
for i, s in nll_list:
next_score = score + s
allowed_max_score = max_score_greedy if i == greedy_index else max_score
if next_score < allowed_max_score:
if i == eos_token_id:
next_suffixes = [(next_score, [], [])]
elif max_new_tokens > 1:
# Get next token logits
next_input_ids = jnp.array([[i]])
next_position_ids = jnp.array([[pos]])
                # Use the KV cache for efficient incremental decoding when available
                if cache is not None:
                    outputs = model(
                        input_ids=next_input_ids,
                        position_ids=next_position_ids,
                        past_key_values=cache,
                    )
                    next_cache = outputs.past_key_values
                else:
                    # No cache available, fall back to a plain forward pass on the new token
                    outputs = model(
                        input_ids=next_input_ids,
                        position_ids=next_position_ids,
                    )
                    next_cache = None
                next_logits = outputs.logits[0, -1]
                next_suffixes = turbo_dfs(
                    model, next_logits, path, eos_token_id, max_new_tokens - 1,
                    max_score, allowed_max_score, temperature, suppress_tokens,
                    score=next_score, pos=pos + 1, cache=next_cache,
                )
else:
next_suffixes = []
# Add current token to suffixes
for suffix in next_suffixes:
suffix[1].append(i)
suffix[2].append(logits)
suffixes.extend(next_suffixes)
return suffixes
def inference_turbo_dfs(model, input_ids, eos_token_id, max_new_tokens, min_prob, min_prob_greedy=1, temperature=1.0, suppress_tokens=[], path=None, attention_mask=None):
"""Perform inference using turbo DFS approach."""
# Convert inputs to JAX arrays if needed
if not isinstance(input_ids, jnp.ndarray):
input_ids = jnp.array(input_ids, dtype=jnp.int32)
# Remove batch dimension if present
if input_ids.ndim == 2:
input_ids = input_ids[0]
assert input_ids.ndim == 1, 'batching not supported'
# Convert min_prob to max_score
max_score = -np.log(min_prob)
max_score_greedy = -np.log(min_prob_greedy) if min_prob_greedy > 0 else float('inf')
max_score_greedy = max(max_score, max_score_greedy)
# Handle path
if path is None:
path = []
if len(path) > 0 and path[-1] == eos_token_id:
path = path[:-1]
# Prepare full path with input_ids + path
if len(path) > 0:
full_path = jnp.concatenate([input_ids, jnp.array(path, dtype=jnp.int32)])
else:
full_path = input_ids
    # Get the logits predicting the first new position (1D over the vocabulary) and the KV cache
    full_path_batched = full_path[None, :]  # Add batch dimension
    model_outputs = model(input_ids=full_path_batched)
    logits = model_outputs.logits[0, len(input_ids) - 1]
    cache = model_outputs.past_key_values if hasattr(model_outputs, 'past_key_values') else None
# Run turbo DFS
result = turbo_dfs(
model, logits, path, eos_token_id, max_new_tokens,
max_score, max_score_greedy, temperature, suppress_tokens,
score=0.0, pos=len(input_ids), cache=cache
)
# Sort by score
sorted_result = sorted(
[(score_val, np.array(suffix[::-1]), np.array(score_arr[::-1]))
for score_val, suffix, score_arr in result],
key=lambda x: x[0]
)
return sorted_result
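# Numeric sanity check of the probability-to-score threshold used above (numbers are
# illustrative): a candidate sequence is kept while its cumulative NLL stays below
# -log(min_prob).
def _example_dfs_threshold(min_prob=0.17):
    max_score = -np.log(min_prob)                  # ~ 1.77
    token_probs = [0.9, 0.6, 0.5]                  # hypothetical per-token probabilities
    cumulative_nll = -np.sum(np.log(token_probs))  # ~ 1.31
    return cumulative_nll < max_score              # True: this continuation would be kept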
def inference_step(tokenized, model, remove_token_type_ids=True, num_beams=1, formatter=None, min_prob=None, current_best=None, **kwargs):
"""Perform a single inference step."""
# Remove token_type_ids if needed
if remove_token_type_ids:
tokenized.pop('token_type_ids', None)
# Convert to JAX arrays
jax_inputs = {k: jnp.array(v) for k, v in tokenized.items()}
if min_prob is not None:
assert num_beams == 1
# Use turbo DFS for inference with min_prob
gen = inference_turbo_dfs(
model,
jax_inputs['input_ids'],
path=current_best,
min_prob=min_prob,
eos_token_id=formatter.tokenizer.eos_token_id,
**kwargs
)
tokens_out = [[g[1] for g in gen]]
scores_out = [[g[2] for g in gen]]
else:
# Use standard generation
input_length = jax_inputs['input_ids'].shape[-1]
# Generate with JAX model
gen_config = {
'max_new_tokens': kwargs.get('max_new_tokens', 20),
'do_sample': kwargs.get('do_sample', False),
'temperature': kwargs.get('temperature', 1.0),
'top_p': kwargs.get('top_p', 1.0),
'top_k': kwargs.get('top_k', 50),
'num_beams': num_beams,
'eos_token_id': formatter.tokenizer.eos_token_id,
'pad_token_id': formatter.tokenizer.pad_token_id,
}
# JAX generation
outputs = model.generate(
**jax_inputs,
**gen_config,
return_dict_in_generate=True,
output_scores=True,
)
# Extract tokens and scores
tokens = outputs.sequences
tokens_out = tokens[:, input_length:][None, :, :]
# Extract scores if available
if hasattr(outputs, 'scores') and outputs.scores:
scores_out = jnp.stack(outputs.scores, axis=1)[None, :, :]
else:
# Create dummy scores if not available
scores_out = jnp.zeros((1, tokens_out.shape[1], tokens_out.shape[2], model.config.vocab_size))
# Convert JAX arrays to numpy
tokens_out = np.array(tokens_out)
scores_out = np.array(scores_out)
return tokens_out, scores_out
def process_inference_output(key, outputs, formatter, store=None, decoder=None, decoder_args={}):
"""Process inference outputs, detokenize, and save results."""
de_tokenized = [formatter.de_tokenize(*output) for output in zip(*outputs)]
inference_save(store, key, de_tokenized)
if decoder is not None:
decoder.process(key, de_tokenized, **decoder_args)
return de_tokenized
def inference_run_v2(model, formatter, dataset, decoder=None, max_new_tokens=None, max_batch_size=1, store=None, result_dict=None, rerun_empty=False, retrain=None, use_turbo=False, group_multi_output=True, **kwargs):
"""Run inference on a dataset with various optimization options."""
assert max_batch_size == 1, 'batch size > 1 not supported in JAX implementation yet'
logger.info('*** Load stored data...')
if result_dict is None:
result_dict = {}
result_dict = inference_load(store, dataset.keys, result_dict)
# Group by base key
by_base_key = {}
needs_rerun = {}
base_key_list = []
for key in dataset.keys:
base_key = key.split('.')[0]
if group_multi_output:
base_key = base_key.split('_')[0]
if base_key not in by_base_key:
base_key_list.append(base_key)
bk_list = by_base_key[base_key] = by_base_key.get(base_key, [])
bk_list.append(key)
# Check which keys need to be rerun
for base_key, keys in by_base_key.items():
for key in keys:
de_tokenized = result_dict.get(key)
if de_tokenized is None or (rerun_empty and not de_tokenized):
bk_list = needs_rerun[base_key] = needs_rerun.get(base_key, [])
bk_list.append(key)
elif decoder is not None:
decoder.process(key, de_tokenized)
# Set up tokenizer padding
formatter.tokenizer.padding_side = 'left'
if max_new_tokens is None:
max_new_tokens = formatter.max_new_tokens()
# Model should be in evaluation mode
model_forward = model
logger.info('*** Start inference run...')
try:
for base_key in tqdm(base_key_list, file=sys.stdout):
run_keys = needs_rerun.get(base_key)
if run_keys:
if retrain is not None:
retrain_dataset = dataset.keep_key_startswith(base_key)
logger.info(f"retraining model for key '{base_key}' (retrain_dataset_size={len(retrain_dataset.keys)})")
retrain(model, retrain_dataset)
for key in run_keys:
input_text = dataset.get(key, formatter)['input']
batch = formatter.tokenizer([input_text], return_tensors='jax')
current_best = decoder.get_current_best(key.split('.')[0]) if use_turbo else None
if current_best is not None:
current_best = dataset.forward_mod(current_best, key)
current_best = formatter.fmt_reply([current_best])
current_best = formatter.tokenizer(input_text+current_best)['input_ids'][batch['input_ids'].shape[-1]:]
batch_out = inference_step(
batch, model_forward, formatter=formatter,
max_new_tokens=max_new_tokens, current_best=current_best, **kwargs
)
outputs = [x[0] for x in batch_out]
result_dict[key] = process_inference_output(
key, outputs, formatter, store=store, decoder=decoder,
decoder_args=dict(print_func=lambda x: logger.info(x))
)
logger.info('*** Completed inference run.')
except KeyboardInterrupt:
logger.info('*** Ctrl+C pressed, stopping inference run.')
return result_dict
class Retrainer(object):
"""Retrainer for test-time training."""
def __init__(self, n, aug_opts, reload_state_dict=None, **kwargs):
self.n = n
self.aug_opts = aug_opts
self.reload_state_dict = reload_state_dict
self.kwargs = kwargs
def preprocess(self, dataset):
"""Preprocess dataset for training."""
ds = [dataset.augment(quiet=True, shfl_keys=True, **self.aug_opts) for _ in range((self.n-1)//dataset.length()+1)]
ds = ds[0] if len(ds)==1 else ds[0].append(*ds[1:])
ds, _ = ds.split_at_pos(self.n)
return ds
def __call__(self, model, dataset):
"""Train model on dataset."""
if self.reload_state_dict is not None:
set_peft_weights(model, self.reload_state_dict)
# Train the model
training_run(model, dataset=self.preprocess(dataset), **self.kwargs)
def calc_score(key, input, reply, formatter, model, store=None, decoder=None, **_):
"""Calculate score for a given input-reply pair."""
# Tokenize input and reply
input_len = len(formatter.tokenizer(input)['input_ids'])
tokenized = formatter.tokenizer([input+reply], return_tensors='jax')
# Get reply tokens
reply_tok = np.array(tokenized['input_ids'][0][input_len:])
# Get logits for scoring
logits = model(
input_ids=tokenized['input_ids'],
attention_mask=tokenized['attention_mask']
).logits[0, input_len-1:-1]
# Convert logits to numpy array
reply_log = np.array(logits)
# Process output
process_inference_output(
key,
(reply_tok[np.newaxis], reply_log[np.newaxis]),
formatter,
store=store,
decoder=decoder
)
def mem_info(device_id=0):
    """Print memory usage information."""
    try:
        # JAX exposes per-device allocator statistics via Device.memory_stats()
        devices = jax.devices()
        if device_id < len(devices):
            device = devices[device_id]
            stats = device.memory_stats() or {}
            used_memory = stats.get('bytes_in_use', 0) / 1024**3
            total_memory = stats.get('bytes_limit', 0) / 1024**3
            logger.info(f"*** Device: {device}, used {used_memory:.3f} / {total_memory:.3f} GB.")
        else:
            logger.info(f"*** Device ID {device_id} not found among available devices.")
    except Exception:
        logger.info('*** Exception occurred when getting memory stats.')