Created
April 26, 2025 00:40
-
-
Save aug2uag/c8458aa453f7582b0a7fbad1de725ecc to your computer and use it in GitHub Desktop.
JAX model runner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# model runner
""" | |
JAX-based Model Runner for ARC Challenges | |
This module provides JAX implementations of the components needed for test-time training | |
and inference on ARC challenges. | |
""" | |
import json | |
import os, sys | |
import bz2 | |
import pickle | |
import numpy as np | |
import jax | |
import jax.numpy as jnp | |
from jax import lax | |
import flax | |
from flax import linen as nn | |
from flax.training import train_state | |
import optax | |
from typing import Dict, List, Tuple, Optional, Any, Callable, Union | |
import logging | |
from tqdm.auto import tqdm | |
from transformers import AutoTokenizer | |
from transformers.generation import FlaxGenerationMixin | |
from dataclasses import dataclass | |
# Set up logging | |
logger = logging.getLogger(__name__) | |
def indices_required_for_merges(keep_indices, vocab, merges):
    """Expand *keep_indices* with every token id needed to re-derive kept tokens via BPE merges.

    Mutates and returns *keep_indices* (a dict used as an ordered set:
    id -> None). For each kept merged token, its two constituent token ids
    are pulled in transitively.
    """
    # merged-token id -> the ids of the two tokens that merge into it
    dependencies = {}
    for merge in merges:
        left, right = merge.split(' ') if isinstance(merge, str) else merge
        merged_id = vocab[f'{left}{right}']
        dependencies.setdefault(merged_id, set()).update((vocab[left], vocab[right]))
    # Breadth-style closure over the dependency graph.
    pending = list(keep_indices)
    while pending:
        for constituent in dependencies.get(pending.pop(), ()):
            if constituent not in keep_indices:
                keep_indices[constituent] = None
                pending.append(constituent)
    return keep_indices
def remove_unused_merges(merges, vocab):
    """Drop BPE merge rules that reference tokens absent from *vocab*.

    Accepts merges as 'a b' strings or (a, b) pairs; returns the surviving
    rules normalized to 'a b' strings. A rule survives only when both parts
    and their concatenation are still in the vocabulary.
    """
    kept = []
    for pair in merges:
        left, right = pair.split(' ') if isinstance(pair, str) else pair
        if left in vocab and right in vocab and left + right in vocab:
            kept.append(f'{left} {right}')
    return kept
def map_special_tokens(data, mapping=None):
    """Collect special-token ids found recursively in a tokenizer-json fragment.

    When *mapping* is given, every 'special_tokens' id list encountered is
    rewritten in place through it (ids missing from the mapping are dropped).
    Returns the set of ids seen (pre-mapping).
    """
    found = set()
    if isinstance(data, dict):
        special = data.get('special_tokens')
        if special is not None:
            for entry in special.values():
                found.update(entry['ids'])
                if mapping is not None:
                    entry['ids'] = [mapping.get(i) for i in entry['ids'] if i in mapping]
    # Recurse into dict values and list elements; other types are leaves.
    if isinstance(data, dict):
        children = data.values()
    elif isinstance(data, list):
        children = data
    else:
        children = ()
    for child in children:
        found |= map_special_tokens(child, mapping)
    return found
def remove_tokenizer_normalizer(tokenizer):
    """Strip the normalizer from a fast tokenizer so input text is tokenized verbatim.

    Args:
        tokenizer: a HF fast tokenizer (must expose `backend_tokenizer`).

    Returns:
        The same tokenizer object, with its backend rebuilt from a JSON
        serialization whose 'normalizer' field has been cleared.
    """
    tokenizer_json = json.loads(tokenizer.backend_tokenizer.to_str())
    if tokenizer_json.get('normalizer') is not None:
        tokenizer_json['normalizer'] = None
        # Bug fix: the previous AutoTokenizer.from_pretrained(None,
        # tokenizer_file=<json string>) call is invalid — `tokenizer_file`
        # expects a filesystem path and from_pretrained(None) raises.
        # Deserialize directly through the backend class (tokenizers.Tokenizer)
        # instead, keeping the surrounding fast-tokenizer wrapper intact.
        tokenizer.backend_tokenizer = type(tokenizer.backend_tokenizer).from_str(
            json.dumps(tokenizer_json)
        )
    return tokenizer
def shrink_tokenizer_vocab(tokenizer, keep_indices, keep_special_tokens, keep_token_order):
    """Shrink a BPE fast tokenizer's vocabulary down to *keep_indices*.

    Args:
        tokenizer: HF fast tokenizer with a BPE backend.
        keep_indices: dict used as ordered set of old token ids to keep.
        keep_special_tokens: also keep all special / post-processor token ids.
        keep_token_order: renumber kept ids in ascending old-id order.

    Returns:
        (tokenizer, mapping, keep_indices) where *mapping* maps old ids to
        new compacted ids; the tokenizer is mutated in place.
    """
    tokenizer_json = json.loads(tokenizer.backend_tokenizer.to_str())
    assert tokenizer_json['model']['type'] == "BPE"
    if keep_special_tokens:
        keep_indices.update({k: None for k in tokenizer.all_special_ids})
        keep_indices.update({k: None for k in map_special_tokens(tokenizer_json.get('post_processor'))})
    # Pull in every id transitively required by the BPE merge rules.
    keep_indices = indices_required_for_merges(keep_indices, tokenizer_json['model']['vocab'], tokenizer_json['model']['merges'])
    if keep_token_order: keep_indices = sorted(keep_indices)
    mapping = {old: new for new, old in enumerate(keep_indices)}
    tokenizer_json['model']['vocab'] = {k: mapping[v] for k, v in tokenizer_json['model']['vocab'].items() if v in mapping}
    tokenizer_json['model']['merges'] = remove_unused_merges(tokenizer_json['model']['merges'], tokenizer_json['model']['vocab'])
    special_tokens_order = [t['id'] for t in tokenizer_json['added_tokens']]
    assert special_tokens_order == sorted(special_tokens_order)
    tokenizer_json['added_tokens'] = sorted([{**t, 'id': mapping[t['id']]} for t in tokenizer_json['added_tokens'] if t['id'] in mapping], key=lambda t: t['id'])
    # Rewrite post-processor special-token ids in place through the mapping.
    map_special_tokens(tokenizer_json.get('post_processor'), mapping)
    # Bug fix: AutoTokenizer.from_pretrained(None, tokenizer_file=<json str>)
    # is not a valid call (tokenizer_file expects a path; from_pretrained(None)
    # raises). Rebuild the backend via the backend class' from_str instead,
    # which keeps the PreTrainedTokenizerFast wrapper usable.
    tokenizer.backend_tokenizer = type(tokenizer.backend_tokenizer).from_str(
        json.dumps(tokenizer_json)
    )
    return tokenizer, mapping, keep_indices
def shrink_model_embeddings(model, keep_indices, mapping):
    """Adjust model embeddings to fit the shrunk vocabulary.

    Args:
        model: Flax causal LM whose `params` pytree holds the embedding table
            at `transformer/wte/embedding` (and optionally a `lm_head/kernel`).
            NOTE(review): this is a GPT-2-style layout — other architectures
            keep embeddings elsewhere; confirm before reuse.
        keep_indices: ordered collection of old token ids to keep.
        mapping: old-id -> new-id dict, used to remap `*token_id` config fields.

    Returns:
        A new model of the same class with reduced embedding rows.
    """
    config = model.config
    # Do the row gather on CPU to avoid device-memory spikes for large vocabs.
    with jax.default_device(jax.devices("cpu")[0]):
        params = model.params
        # Extract original embedding weights
        embed_weights = params['transformer']['wte']['embedding']
        lm_head_weights = params['lm_head']['kernel'] if 'lm_head' in params else None
        # Select only the kept token rows, in their new order.
        row_select = jnp.array(list(keep_indices), dtype=jnp.int32)
        new_embed_weights = embed_weights[row_select]
        # Update model parameters (mutates the params dict in place)
        params['transformer']['wte']['embedding'] = new_embed_weights
        if lm_head_weights is not None:
            # NOTE(review): assumes lm_head kernel rows are vocab-indexed; a
            # Flax Dense kernel is often (in_features, vocab) — verify the axis.
            new_lm_head_weights = lm_head_weights[row_select]
            params['lm_head']['kernel'] = new_lm_head_weights
        # Remap every `*token_id` config entry into the new id space
        # (ids not kept become None via mapping.get).
        for config_obj in [config]:
            for k, v in list(config_obj.to_dict().items()):
                if k.endswith('token_id'):
                    setattr(config_obj, k, [mapping.get(t) for t in v] if isinstance(v, list) else mapping.get(v))
    # Create new model with updated parameters
    new_model = type(model)(config=config)
    new_model.params = params
    return new_model
def shrink_embeddings(model, tokenizer, corpus=None, keep_token_ids=None, keep_tokens=None, remove_token_ids=None, keep_model_tokens=True, keep_special_tokens=True, keep_normalizer=False, keep_token_order=True):
    """Shrink both the tokenizer vocabulary and the model embedding rows.

    Args:
        model: Flax causal LM to shrink.
        tokenizer: matching HF fast tokenizer.
        corpus: optional text whose tokenization determines ids to keep.
        keep_token_ids: extra token ids to keep (default: none).
        keep_tokens: extra token strings to keep (default: none).
        remove_token_ids: ids to force-remove after collection (default: none).
        keep_model_tokens: keep every `*token_id` referenced by the config.
        keep_special_tokens: keep special / post-processor tokens.
        keep_normalizer: when falsy, strip the tokenizer normalizer first.
        keep_token_order: renumber kept ids in ascending old-id order.

    Returns:
        (new_model, new_tokenizer, mapping) with mapping: old id -> new id.
    """
    # Bug fix: the defaults were mutable lists ([]) shared across calls;
    # use None sentinels instead (behavior is unchanged for all callers).
    keep_token_ids = [] if keep_token_ids is None else keep_token_ids
    keep_tokens = [] if keep_tokens is None else keep_tokens
    remove_token_ids = [] if remove_token_ids is None else remove_token_ids
    if not keep_normalizer:
        tokenizer = remove_tokenizer_normalizer(tokenizer)
    from collections import OrderedDict  # used as an ordered set
    keep_indices = OrderedDict()
    keep_indices.update({k: None for k in keep_token_ids})
    keep_indices.update({tokenizer.vocab[t]: None for t in keep_tokens})
    if corpus is not None:
        keep_indices.update({k: None for k in tokenizer(corpus)['input_ids']})
    if keep_model_tokens:
        # Keep every token id the model config refers to (eos, pad, ...).
        for config in [model.config]:
            for k, v in config.to_dict().items():
                if k.endswith('token_id'):
                    keep_indices.update({k: None for k in (v if isinstance(v, list) else [v])})
    # Unset config token ids appear as None — drop them.
    keep_indices.pop(None, None)
    for idx in remove_token_ids:
        keep_indices.pop(idx, None)
    new_tokenizer, mapping, keep_indices = shrink_tokenizer_vocab(tokenizer, keep_indices, keep_special_tokens, keep_token_order)
    new_model = shrink_model_embeddings(model, keep_indices, mapping=mapping)
    return new_model, new_tokenizer, mapping
def fix_dtypes(model, fix_weights=True, fix_quant_states=True):
    """Cast all floating-point parameters of *model* to a single dtype (float32).

    Args:
        model: object with a pytree `.params` attribute; mutated in place.
        fix_weights: when True, floating arrays are cast to the default dtype.
        fix_quant_states: unused here; kept for signature compatibility with
            the original PyTorch implementation.

    Returns:
        The same model object with `model.params` rebuilt.
    """
    params = model.params
    default_dtype = jnp.float32  # or model's default dtype
    def fix_param_dtype(param):
        # Only touch floating-point arrays; integers/bools pass through.
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            if fix_weights and param.dtype != default_dtype:
                return param.astype(default_dtype)
        return param
    # Bug fix: jax.tree_map is deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(fix_param_dtype, params)
    model.params = new_params
    return model
def merge_peft_into_base(model):
    """Fold the model's LoRA weights into its base parameters.

    Returns a fresh model instance of the same class carrying the merged
    parameters, passed through fix_dtypes for a consistent dtype.
    """
    logger.info('*** Merge peft model into base model...')
    merged = merge_lora_weights(model.params, model.lora_params, model.config.lora_config)
    merged_model = type(model)(config=model.config)
    merged_model.params = merged
    return fix_dtypes(merged_model)
def merge_lora_weights(base_params, lora_params, lora_config):
    """Merge LoRA low-rank updates into the base parameter pytree.

    For every base weight whose path contains one of
    `lora_config.target_modules`, looks up `lora_A`/`lora_B` factors stored at
    the same path inside *lora_params* and adds `(B @ A) * alpha / r`.

    Bug fix: the previous implementation called `jax.tree_util.tree_get` /
    `tree_set`, which do not exist in JAX. This version walks the nested
    dicts explicitly. *base_params* is not mutated; a structurally new tree
    is returned (unchanged leaves are shared).

    Args:
        base_params: nested-dict pytree of base weights.
        lora_params: nested-dict pytree with `lora_A`/`lora_B` sub-entries.
        lora_config: provides `target_modules`, `lora_alpha`, `r`.

    Returns:
        The merged parameter pytree.
    """
    def _leaves(tree, prefix=()):
        # Yield (path, leaf) pairs for a nested-dict pytree.
        if isinstance(tree, dict):
            for key, sub in tree.items():
                yield from _leaves(sub, prefix + (key,))
        else:
            yield prefix, tree

    def _lookup(tree, path):
        # Value at *path*, or None when any segment is missing.
        for key in path:
            if not isinstance(tree, dict) or key not in tree:
                return None
            tree = tree[key]
        return tree

    def _assign(tree, path, value):
        # Copy-on-write set: rebuild only the dicts along *path*.
        if len(path) == 1:
            return {**tree, path[0]: value}
        return {**tree, path[0]: _assign(tree[path[0]], path[1:], value)}

    merged_params = base_params
    # LoRA update scale: alpha / r.
    scale = lora_config.lora_alpha / lora_config.r
    for path, param in _leaves(base_params):
        path_str = '/'.join(str(p) for p in path)
        if not any(m in path_str for m in lora_config.target_modules):
            continue
        lora_a = _lookup(lora_params, path + ('lora_A',))
        lora_b = _lookup(lora_params, path + ('lora_B',))
        if lora_a is not None and lora_b is not None:
            merged_params = _assign(merged_params, path, param + (lora_b @ lora_a) * scale)
    return merged_params
def save_model(store_path, model=None, tokenizer=None, merge=False):
    """Save model and tokenizer to disk.

    Args:
        store_path: target directory (created if missing); None skips saving.
        model: model to save; when None only the tokenizer is written.
        tokenizer: tokenizer to save alongside the model.
        merge: when True, LoRA weights are merged into the base model first.

    Returns:
        The (possibly merged) model.
    """
    if merge and model is not None:
        model = merge_peft_into_base(model)
    if store_path is not None:
        assert model is not None or tokenizer is not None
        logger.info(f"*** Saving{' merged' if merge else ''} model/tokenizer to '{store_path}'...")
        os.makedirs(store_path, exist_ok=True)
        if model is not None:
            # Save model config and parameters
            with open(os.path.join(store_path, 'config.json'), 'w') as f:
                json.dump(model.config.to_dict(), f)
            # NOTE(review): the payload written here is flax msgpack, not the
            # safetensors format — the 'model.safetensors' filename is
            # misleading and loaders expecting real safetensors will fail.
            with open(os.path.join(store_path, 'model.safetensors'), 'wb') as f:
                params_bytes = flax.serialization.msgpack_serialize(model.params)
                f.write(params_bytes)
        if tokenizer is not None:
            tokenizer.save_pretrained(store_path)
            # Remove unnecessary tokenizer model file if it exists
            to_delete = os.path.join(store_path, 'tokenizer.model')
            if os.path.isfile(to_delete):
                os.remove(to_delete)
    return model
def is_unsloth_model(model):
    """Always False: unsloth is a PyTorch-only stack, absent from this JAX port.

    The signature is kept so torch-era call sites keep working.
    """
    del model  # unused
    return False
def is_peft_model(model):
    """A model counts as PEFT-wrapped when it carries a `lora_params` attribute."""
    try:
        model.lora_params
    except AttributeError:
        return False
    return True
def download_model(repo_id, store_path, get_name=lambda n: os.path.join(n.replace('/', '--'), 'transformers', 'default', '1')):
    """Resolve *repo_id* to a local model path, downloading from HF Hub on a cache miss.

    Args:
        repo_id: an existing local path (returned as-is) or a HF repo id.
        store_path: root directory of the local model cache.
        get_name: maps a repo id to its relative path inside the cache.

    Returns:
        Filesystem path containing the model files.
    """
    # (removed a redundant function-local `import os`; os is imported at module level)
    if os.path.exists(repo_id):
        # Already local; nothing to download.
        return repo_id
    model_path = os.path.join(store_path, get_name(repo_id))
    if not os.path.exists(model_path):
        from huggingface_hub import snapshot_download
        download_path = snapshot_download(repo_id=repo_id)
        os.makedirs(os.path.split(model_path)[0], exist_ok=True)
        # Symlink the HF cache snapshot into our layout instead of copying it.
        os.symlink(download_path, model_path, target_is_directory=True)
    return model_path
def get_and_fix_peft_weights(store):
    """Load LoRA weights from *store* and drop 'modules_to_save' entries.

    Reads `lora_params.msgpack` from the given directory, then filters out
    keys containing 'modules_to_save' (full-module copies saved by PEFT).
    """
    logger.info(f"*** Load peft state_dict from '{store}'...")
    with open(os.path.join(store, 'lora_params.msgpack'), 'rb') as f:
        state_dict = flax.serialization.msgpack_deserialize(f.read())
    return {k: v for k, v in state_dict.items() if 'modules_to_save' not in k}
def set_peft_weights(model, state_dict):
    """Attach *state_dict* to *model* as its LoRA parameter tree and return the model."""
    logger.info(f"*** Set model state_dict...")
    model.lora_params = state_dict
    return model
@dataclass
class LoraConfig:
    """Configuration for LoRA adapters.

    Converted to a dataclass (the file already uses `dataclasses`); the
    constructor signature — names, order, defaults — is unchanged.

    Attributes:
        r: LoRA rank.
        target_modules: parameter-path substrings to adapt; None becomes [].
        lora_alpha: scaling numerator; effective update scale is alpha / r.
        lora_dropout: kept for PEFT-API compatibility.
        bias: bias-handling mode, kept for PEFT-API compatibility.
        use_gradient_checkpointing: kept for API compatibility.
        random_state: PRNG seed for LoRA initialization.
        use_rslora: use the RS-LoRA initialization variant.
        loftq_config: kept for API compatibility.
    """
    r: int = 8
    target_modules: Optional[List[str]] = None
    lora_alpha: float = 16
    lora_dropout: float = 0.0
    bias: str = "none"
    use_gradient_checkpointing: bool = False
    random_state: int = 42
    use_rslora: bool = False
    loftq_config: Any = None

    def __post_init__(self):
        # Preserve the original behavior: falsy target_modules becomes [].
        self.target_modules = self.target_modules or []
def apply_lora_to_model(model, lora_config):
    """Initialize LoRA factors for every targeted 2-D weight and attach them to *model*.

    For each parameter whose path contains one of `lora_config.target_modules`,
    creates `lora_A` (r x in) and `lora_B` (out x r) at the same path inside
    `model.lora_params`.

    Bug fix: the previous implementation relied on `jax.tree_util.tree_set`,
    which does not exist in JAX (so the lora_params dict was never actually
    built). This version constructs the nested dict directly.

    Returns:
        The same model, with `lora_params` attached and the config recorded.
    """
    # Deterministic initialization from the configured seed.
    rng = jax.random.PRNGKey(lora_config.random_state)

    def _leaves(tree, prefix=()):
        # (path, leaf) pairs of a nested-dict pytree.
        if isinstance(tree, dict):
            for key, sub in tree.items():
                yield from _leaves(sub, prefix + (key,))
        else:
            yield prefix, tree

    def _assign(tree, path, value):
        # In-place nested set, creating intermediate dicts as needed.
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value

    lora_params = {}
    for path, param in _leaves(model.params):
        path_str = '/'.join(str(p) for p in path)
        if not any(m in path_str for m in lora_config.target_modules):
            continue
        if isinstance(param, jnp.ndarray) and param.ndim == 2:  # weight matrix
            rng, key_a, key_b = jax.random.split(rng, 3)
            if lora_config.use_rslora:
                # RS-LoRA: both factors start small but non-zero.
                lora_a = jax.random.normal(key_a, (lora_config.r, param.shape[1])) * 0.1
                lora_b = jax.random.normal(key_b, (param.shape[0], lora_config.r)) * 0.1
            else:
                # Standard LoRA: B starts at zero so the initial update is a no-op.
                lora_a = jax.random.normal(key_a, (lora_config.r, param.shape[1])) * 0.01
                lora_b = jnp.zeros((param.shape[0], lora_config.r))
            _assign(lora_params, path + ('lora_A',), lora_a)
            _assign(lora_params, path + ('lora_B',), lora_b)
    model.lora_params = lora_params
    model.config.lora_config = lora_config
    return model
def prepare_model(model, mode, tokenizer=None, formatter=None, shrink_embedding=False,
                  dequantize=False, peft=[], local_files_only=False, add_special_tokens={},
                  set_pad_token=None, keep_tokens=[], keep_normalizer=None,
                  peft_trainable=True, device_map=None, tf_grad_cp=True, tf_use_fa2=True, **kwargs):
    """Prepare model and tokenizer with various optimizations.

    Args:
        model: model object, or a path/repo-id string to load from.
        mode: 'tokenizer_only' skips model loading; other values load the model.
        tokenizer: must be None when *model* is a string (it is loaded here).
        formatter: formatter instance, or a factory called with tokenizer=...
        shrink_embedding: bool, or an int threshold — shrink when the current
            vocab is larger than this value.
        peft: sequence describing PEFT stages; each entry is a path (load
            weights), a dict (create a new LoRA config), or None.
        add_special_tokens / set_pad_token / keep_tokens / keep_normalizer:
            tokenizer adjustments forwarded to the helpers below.
        dequantize, peft_trainable, device_map, tf_grad_cp, tf_use_fa2:
            accepted for API compatibility with the torch version; unused here.
        **kwargs: forwarded to AutoTokenizer.from_pretrained.

    Returns:
        (model, tokenizer, formatter).

    NOTE(review): the defaults peft=[], add_special_tokens={}, keep_tokens=[]
    are mutable (shared across calls); they are not mutated below, but this
    is fragile — consider None sentinels.
    """
    if isinstance(model, str):
        assert tokenizer is None
        logger.info(f"*** Load base model and tokenizer from '{model}'...")
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model, local_files_only=local_files_only, **kwargs)
        # Load model based on mode
        if mode != 'tokenizer_only':
            from transformers import FlaxAutoModelForCausalLM
            # JAX doesn't support device mapping like PyTorch, ignore device_map
            model_load_args = {}
            # Load model
            model = FlaxAutoModelForCausalLM.from_pretrained(model, **model_load_args)
        else:
            model = None
    # Add special tokens if needed
    if add_special_tokens:
        tokenizer.add_special_tokens(add_special_tokens)
    # Set pad token if needed
    if set_pad_token is not None:
        tokenizer.pad_token = set_pad_token
    # Instantiate the formatter factory once, binding the tokenizer.
    if formatter is not None and not hasattr(formatter, 'corpus'):
        formatter = formatter(tokenizer=tokenizer)
    # Shrink when an int threshold is below the vocab size, when a truthy
    # bool was passed, or when the normalizer must be stripped anyway.
    if (shrink_embedding < len(tokenizer.vocab) if type(shrink_embedding) == int else shrink_embedding) or keep_normalizer is False:
        logger.info('*** Shrink embedding...')
        embedding_size_before_shrink = len(tokenizer.vocab)
        model, tokenizer, mapping = shrink_embeddings(
            model, tokenizer, formatter.get_corpus(),
            keep_tokens=keep_tokens, keep_normalizer=keep_normalizer
        )
        logger.info(f'*** -> Reduced embedding size from {embedding_size_before_shrink} to {len(mapping)} words.')
    # Apply LoRA or other parameter-efficient methods.
    # peft_trained state machine: True = model carries trained LoRA weights,
    # False = fresh (untrained) LoRA attached, None = no LoRA attached.
    if len(peft):
        peft_trained = True if is_peft_model(model) else None
        for i, m in enumerate(peft):
            if peft_trained is True:
                # Fold previously trained adapters in before stacking another.
                model, peft_trained = merge_peft_into_base(model), None
            if isinstance(m, str):
                if peft_trained is False:
                    # NOTE(review): load_peft_state is not defined anywhere in
                    # this module — this branch raises NameError if reached.
                    _, peft_trained = load_peft_state(model, m), True
                else:
                    logger.info(f"*** Load peft model from '{m}'...")
                    state_dict = get_and_fix_peft_weights(m)
                    model = set_peft_weights(model, state_dict)
                    peft_trained = True
            else:
                assert peft_trained is None
                if isinstance(m, dict):
                    logger.info('*** Create new peft model...')
                    lora_config = LoraConfig(**m)
                    model = apply_lora_to_model(model, lora_config)
                    peft_trained = False
                else:
                    assert m is None
    return model, tokenizer, formatter
@dataclass
class TrainingConfig:
    """Hyperparameters for a training run (loose subset of HF TrainingArguments)."""
    batch_size: int = 1
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 1
    warmup_steps: int = 100
    # -1 means "no explicit cap"; training_run derives total steps from the data.
    max_steps: int = -1
    seed: int = 42
    logging_steps: int = 10
    save_strategy: str = "no"
    output_dir: str = "./output"
    gradient_accumulation_steps: int = 1   # NOTE(review): not honored by training_run
    embedding_learning_rate: float = 1e-5  # NOTE(review): unused by create_optimizer
    lr_scheduler_type: str = "linear"      # 'linear' | 'cosine' | other -> constant
    optim: str = "adamw"                   # NOTE(review): only adamw is implemented
def create_optimizer(config):
    """Build an optax AdamW optimizer with the configured LR schedule.

    Supported `config.lr_scheduler_type` values: 'linear', 'cosine'; any
    other value yields a constant learning rate.

    Bug fix: `config.max_steps` defaults to -1 ("uncapped") in TrainingConfig,
    and the previous code passed it straight into optax as
    `transition_steps` / `decay_steps`, producing a nonsensical negative
    schedule horizon. When no positive horizon is available we now fall back
    to a constant learning rate; callers that pass a positive max_steps get
    exactly the previous behavior.

    Args:
        config: TrainingConfig-like object (learning_rate, weight_decay,
            warmup_steps, max_steps, lr_scheduler_type).

    Returns:
        An optax GradientTransformation.
    """
    horizon = config.max_steps if config.max_steps > 0 else None
    if config.lr_scheduler_type == "linear" and horizon is not None:
        # Hold at the initial LR for `warmup_steps`, then decay linearly to 0.
        schedule_fn = optax.linear_schedule(
            init_value=config.learning_rate,
            end_value=0,
            transition_steps=horizon,
            transition_begin=config.warmup_steps
        )
    elif config.lr_scheduler_type == "cosine" and horizon is not None:
        schedule_fn = optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=config.learning_rate,
            warmup_steps=config.warmup_steps,
            decay_steps=horizon,
            end_value=0.0
        )
    else:  # constant learning rate (explicit, or fallback when no horizon)
        schedule_fn = config.learning_rate
    # AdamW with decoupled weight decay.
    optimizer = optax.adamw(
        learning_rate=schedule_fn,
        weight_decay=config.weight_decay,
        b1=0.9,
        b2=0.999,
        eps=1e-8
    )
    return optimizer
def training_run(model, formatter, dataset, train_args, max_seq_length, merge=False,
                 store=None, packing=False, grad_acc_fix=False, optimizers=None):
    """Run a next-token-prediction training loop over *dataset*.

    Args:
        model: model exposing `.apply` and `.params` (Flax-style).
        formatter: provides the tokenizer and example formatting.
        dataset: provides `as_list(formatter)` -> [{'text': ...}, ...].
        train_args: kwargs used to build a TrainingConfig.
        max_seq_length: pad/truncate length for tokenization.
        merge: must be False (merge-after-training unsupported here).
        store: optional directory to save the trained model/tokenizer.
        packing, grad_acc_fix: accepted for API compatibility; unused.
        optimizers: optional tuple whose first element overrides create_optimizer.

    Returns:
        (model, {'metrics': ...}) — metrics kept for torch-API compatibility.
    """
    assert merge is False, "Merge after training not supported in JAX implementation"
    # Create training config
    config = TrainingConfig(**train_args)
    # Right padding so the causal shift below lines labels up with logits.
    formatter.tokenizer.padding_side = 'right'
    train_data = dataset.as_list(formatter)
    # Create optimizer
    if optimizers is None:
        optimizer = create_optimizer(config)
    else:
        optimizer = optimizers[0]
    # Create training state
    rng = jax.random.PRNGKey(config.seed)
    # Define training step
    def train_step(state, batch, rng):
        """Single training step."""
        def loss_fn(params):
            # NOTE(review): HF Flax models are normally invoked as
            # model(input_ids, params=params, ...), not via nn.Module.apply —
            # confirm this call matches the model class actually used.
            logits = model.apply(
                {"params": params},
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                deterministic=False,
                rngs={"dropout": rng}
            ).logits
            # Shift logits and labels for next token prediction
            shift_logits = logits[:, :-1]
            shift_labels = batch["labels"][:, 1:]
            # Cross-entropy against one-hot targets; one_hot(-100) is an
            # all-zero row, and those positions are masked out below anyway.
            loss = optax.softmax_cross_entropy(
                shift_logits,
                jax.nn.one_hot(shift_labels, shift_logits.shape[-1])
            )
            # Mask out padding tokens (-100 labels)
            mask = (shift_labels != -100)
            loss = (loss * mask).sum() / mask.sum()
            return loss
        # Get gradients
        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        # Update parameters
        new_state = state.apply_gradients(grads=grads)
        return new_state, loss
    # Create initial training state
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=model.params,
        tx=optimizer
    )
    # JIT compile the training step
    jit_train_step = jax.jit(train_step)
    # Training loop
    rng, train_rng = jax.random.split(rng)
    logger.info("*** Start training run...")
    # Calculate number of training steps
    num_examples = len(train_data)
    # NOTE(review): steps_per_epoch is 0 when num_examples < batch_size,
    # making the modulo below divide by zero — confirm callers always supply
    # at least one full batch.
    steps_per_epoch = num_examples // config.batch_size
    total_steps = steps_per_epoch * config.num_epochs
    if config.max_steps > 0:
        total_steps = min(total_steps, config.max_steps)
    # Training loop
    step = 0
    epoch = 0
    with tqdm(total=total_steps, desc="Training") as progress_bar:
        while step < total_steps:
            # Reshuffle at every epoch boundary (including step 0).
            if step % steps_per_epoch == 0:
                rng, data_rng = jax.random.split(rng)
                shuffled_idx = jax.random.permutation(data_rng, jnp.arange(num_examples))
                train_data = [train_data[i] for i in shuffled_idx]
                epoch += 1
            # Get batch
            batch_start = (step % steps_per_epoch) * config.batch_size
            batch_end = min(batch_start + config.batch_size, num_examples)
            batch_data = train_data[batch_start:batch_end]
            # Tokenize the raw text of this batch.
            input_texts = [item["text"] for item in batch_data]
            encodings = formatter.tokenizer(
                input_texts,
                max_length=max_seq_length,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )
            # Labels are the inputs with padding replaced by -100 (ignored in loss).
            input_ids = encodings["input_ids"]
            labels = jnp.where(
                input_ids == formatter.tokenizer.pad_token_id,
                -100,  # Ignore padding tokens in loss
                input_ids
            )
            # Modify labels based on formatter's data collator if needed
            collator = formatter.get_data_collator()
            if collator is not None:
                # NOTE(review): collator-driven label masking (e.g. masking
                # prompt tokens) is not implemented — labels may include
                # prompt positions the torch version excluded from the loss.
                pass
            # Prepare batch for JAX
            batch = {
                "input_ids": input_ids,
                "attention_mask": encodings["attention_mask"],
                "labels": labels
            }
            # Train step
            rng, step_rng = jax.random.split(rng)
            state, loss = jit_train_step(state, batch, step_rng)
            if step % config.logging_steps == 0:
                logger.info(f"Step {step}: loss = {loss.item():.4f}")
            step += 1
            progress_bar.update(1)
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
    # Copy trained parameters back onto the model object
    model.params = state.params
    # Save model if needed
    if store is not None:
        save_model(store, model, formatter.tokenizer, merge=merge)
    logger.info("*** Training completed.")
    return model, {"metrics": {"train_runtime": None}}  # For compatibility
def inference_load(store, keys=True, result_dict=None, always_read_from_file=False):
    """Load cached inference outputs (bz2-compressed pickles) from *store*.

    Args:
        store: directory of cached outputs, or None to skip loading.
        keys: iterable of filenames to load, or True to load every file.
        result_dict: dict to populate; a fresh one is created when None.
        always_read_from_file: when False, keys already present in
            *result_dict* are not re-read.

    Returns:
        The populated *result_dict*.
    """
    if result_dict is None:
        result_dict = {}
    if store is not None:
        if keys is True:
            keys = os.listdir(store)
        for key in keys:
            if always_read_from_file or key not in result_dict:
                try:
                    with bz2.BZ2File(os.path.join(store, key)) as f:
                        result_dict[key] = pickle.load(f)
                except Exception:
                    # Best-effort cache read: skip unreadable/corrupt entries.
                    # (Bug fix: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt and SystemExit.)
                    continue
    return result_dict
def inference_save(store, key, outputs):
    """Persist *outputs* for *key* as a bz2-compressed pickle under *store*.

    A None *store* makes this a no-op, mirroring inference_load.
    """
    if store is None:
        return
    os.makedirs(store, exist_ok=True)
    target = os.path.join(store, key)
    with bz2.BZ2File(target, 'w') as f:
        pickle.dump(outputs, f)
class Decoder:
    """Decoder for model outputs.

    Accumulates de-tokenized model outputs per task key, undoes dataset
    augmentations, tracks correctness against known replies, and keeps a
    cumulative probability per distinct output grid to pick best guesses.
    """
    def __init__(self, formatter, dataset, n_guesses, max_outputs=None, frac_score=False,
                 quiet=False, name='', additional_decoders=None, prob_baseline=None):
        self.formatter = formatter
        self.dataset = dataset
        self.n_guesses = n_guesses              # guesses counted for the limited ("2-guess") score
        self.decoded_results = {}               # base_key -> {augmented key -> decoded dict}
        self.correct_solutions = {}             # base_key -> reference output grid
        self.keys_lim = set()                   # keys solved within n_guesses
        self.keys_all = set()                   # keys solved by any guess
        self.mult_cnt = {}                      # task id -> number of sub-puzzles seen
        self.keys_cnt = {}                      # base_key -> number of processed outputs
        self.frac_score = frac_score
        self.max_outputs = max_outputs
        self.quiet = quiet
        # Token lengths for the dataset and its transposed twin (index 1); used for logging.
        self.input_len = [{} if formatter is not None and formatter.tokenizer is None else ds.get_lengths(formatter, name='input') for ds in [dataset, dataset.mod(np.transpose, keep_key=True)]]
        self.reply_len = [{} if formatter is not None and formatter.tokenizer is None else ds.get_lengths(formatter, name='reply') for ds in [dataset, dataset.mod(np.transpose, keep_key=True)]]
        self.additional_decoders = additional_decoders
        self.name = name
        self.prob_tracker = {}                  # base_key -> {output grid (tuple-of-tuples) -> accumulated prob}
        self.prob_tracker_best = {}             # base_key -> (best prob, output grid)
        self.prob_baseline = prob_baseline
    def score(self, *to_score):
        """Score the decoded results.

        Returns ([score per key-set], total count); with frac_score each
        sub-puzzle contributes 1/(number of sub-puzzles of its task).
        """
        scores = [(sum(1/self.mult_cnt[k.split('_')[0]] for k in s) if self.frac_score else len(s)) for s in to_score]
        score_cnt = len(self.mult_cnt if self.frac_score else self.keys_cnt)
        return scores, score_cnt
    def from_store(self, store, **kwargs):
        """Load results from store and process them; returns self for chaining."""
        for key, outputs in inference_load(store).items():
            self.process(key, outputs, **kwargs)
        return self
    def score_fmt(self, v):
        """Format score for display."""
        return f'{v:5.1f}' if self.frac_score else f'{v:3}'
    def process_single_output(self, key, output_len, decoded, print_func=print, len_info=None, device_info=None):
        """Process a single output from the model."""
        # Undo augmentation (transpose/permutation) for grid-valued entries;
        # '*val' entries are scalars and are kept as-is.
        inv_mod = {k: v if k.endswith('val') else self.dataset.invert_mod(v, key, inv_perm=(k.startswith('output') or k.startswith('score_all'))) for k, v in decoded.items()}
        base_key = key.split('.')[0]
        self.decoded_results[base_key] = self.decoded_results.get(base_key, {})
        self.decoded_results[base_key][key] = inv_mod
        output = inv_mod.get('output')
        score = inv_mod.get('score')
        # quick scoring
        self.keys_cnt[base_key] = self.keys_cnt.get(base_key, 0) + 1
        mult_key, mult_sub = (base_key.split('_') + ['0'])[:2]
        self.mult_cnt[mult_key] = max(self.mult_cnt.get(mult_key, 0), int(mult_sub) + 1)
        # NOTE(review): everything below only runs when the dataset has
        # reference replies; for pure-inference datasets (no replies) the
        # probability tracker is never filled — confirm this is intended.
        if len(self.dataset.replies):
            correct_solution = self.dataset.replies.get(base_key)
            if correct_solution is not None:
                correct_solution = correct_solution[0]
                self.correct_solutions[base_key] = correct_solution
            is_correct = correct_solution is not None and np.array_equal(correct_solution, output)
            if is_correct:
                self.keys_all.add(base_key)
                if self.keys_cnt[base_key] <= self.n_guesses: self.keys_lim.add(base_key)
            corr_str = 'cant_decode' if output is None else 'sol_unknown' if correct_solution is None else 'ALL_CORRECT' if is_correct else 'bad_xy_size' if np.shape(correct_solution)!=np.shape(output) else 'bad_content'
            (score_lim, score_all), score_cnt = self.score(self.keys_lim, self.keys_all)
            # Parity of transposing augmentations selects the length table.
            tp_arr = (key.count('transpose') + key.count('rot90')) % 2
            msc = None if score is None else np.sum(score)
            fsc = inv_mod.get('score_val')
            if output is not None and fsc is not None:
                # Accumulate probability mass per distinct output grid.
                pt = self.prob_tracker[base_key] = self.prob_tracker.get(base_key, {})
                hash = tuple(map(tuple, output))
                prob = pt[hash] = pt.get(hash, 0) + (np.exp(fsc) if self.prob_baseline is None else fsc - np.log(self.prob_baseline))
                current_best = self.prob_tracker_best.get(base_key)
                if current_best is None or current_best[0]<prob:
                    self.prob_tracker_best[base_key] = (prob, output)
            fmt_name = f'{self.name}: ' if self.name else ''
            msc_print = f'{min(-msc, 9.99999):7.5f}' if msc is not None else 'unknown'
            fsc_print = f'{min(-fsc, 9.99999):7.5f}' if fsc is not None else 'unknown'
            if not self.quiet: print_func(f" {fmt_name}acc: {self.score_fmt(score_lim)}/{score_cnt:3}={min(score_lim/score_cnt, 0.999):5.1%} (2-guess), {self.score_fmt(score_all)}/{score_cnt:3}={min(score_all/score_cnt, 0.999):5.1%} (any);{f' {device_info}' if device_info else ''} tok:{self.input_len[tp_arr].get(base_key, '?'):>4}+{self.reply_len[tp_arr].get(base_key, '?'):>3}>{'n/a' if output_len is None else output_len:>3} {corr_str}:{msc_print}|{fsc_print} [{key}]")
    def get_current_best(self, base_key):
        """Get the current best output for a key (highest accumulated probability)."""
        current_best = self.prob_tracker_best.get(base_key)
        return None if current_best is None else current_best[1]
    def process_single_decode(self, key, de_tokenized, print_func=print, **kwargs):
        """Process single decoded output."""
        if len(de_tokenized)==3 and not isinstance(de_tokenized[1], float):  # for backwards compatibility
            output_len, *data = de_tokenized
            score_val = None
        else: output_len, score_val, *data = de_tokenized
        if self.formatter is None:
            assert len(data) == 1
            decoded = [data[0]]
        else: decoded = self.formatter.decode_to_array(*data)
        for d in decoded: d['score_val'] = score_val
        # First decode goes to self; extras go to additional_decoders or
        # are tracked under derived '<key>.fixN' keys.
        for i, dec in enumerate(decoded):
            if i==0: self.process_single_output(key, output_len, dec, print_func=print_func, **kwargs)
            elif self.additional_decoders:
                if i-1<len(self.additional_decoders): self.additional_decoders[i-1].process_single_output(key, output_len, dec, print_func=print_func, **kwargs)
                else: print_func(f'{key} no decoder available for output #{i}')
            else: self.process_single_output(f'{key}.fix{i}', output_len, dec, print_func=print_func, **kwargs)
    def process(self, key, de_tokenized, **kwargs):
        """Process all decoded outputs for a key."""
        for i, d in enumerate(de_tokenized):
            if self.max_outputs is None or i<=self.max_outputs:
                self.process_single_decode(f'{key}.out{i}', d, **kwargs)
    def get_unsolved_keys(self):
        """Get keys that haven't been solved yet."""
        unsolved = []
        for base_key, reply in self.dataset.replies.items():
            if not any(np.array_equal(reply[0], s.get('output')) for s in self.decoded_results.get(base_key, {}).values()):
                unsolved.append(base_key)
        return unsolved
    def run_selection_algo(self, selection_algorithm):
        """Run a selection algorithm on the decoded results.

        Feeds only decodable outputs to the algorithm; base keys with no
        decodable output select an empty list.
        """
        return {bk: (selection_algorithm({k: g for k, g in v.items() if g.get('output') is not None}) if any(g.get('output') is not None for g in v.values()) else []) for bk, v in self.decoded_results.items()}
    def benchmark_selection_algos(self, selection_algorithms, skip_failed=True):
        """Benchmark various selection algorithms."""
        import numpy as np  # NOTE(review): redundant — numpy is imported at module level
        results = {}
        logger.info('*** Benchmark selection algorithms...')
        for selection_algorithm in selection_algorithms:
            name = selection_algorithm.__name__
            try:
                selected = self.run_selection_algo(selection_algorithm)
                if self.formatter is not None:
                    for sols in selected.values():
                        for s in sols:
                            assert self.formatter.is_valid_solution(s), f'found invalid solutions {s}'
                # Only the first n_guesses selections count toward the score.
                correct_keys = {k for k, v in selected.items() if self.correct_solutions.get(k) is not None and any(np.array_equal(guess, self.correct_solutions[k]) for guess in v[:self.n_guesses])}
                (score,), score_cnt = self.score(correct_keys)
                results[name] = score
                logger.info(f" acc: {score:5.1f}/{score_cnt:3}={score/score_cnt:6.2%} ('{name}')")
            except:
                # NOTE(review): bare except also swallows KeyboardInterrupt;
                # it is only re-raised when skip_failed is False.
                logger.info(f" {'execution failed':>21} ('{name}')")
                if not skip_failed: raise
        return results
    def calc_augmented_scores(self, model, base_keys=None, store=None, seed=0, max_len=None, make_unique=True, quiet=False, **kwargs):
        """Calculate augmented scores for outputs.

        For each distinct output grid, re-scores it under dataset
        augmentations (caching results under *store*) and attaches the
        'score_multi*' entries to the decoded result dict.
        NOTE(review): relies on a module-level `calc_score` helper that is
        defined elsewhere in the project, not in this file.
        """
        if base_keys is None: base_keys = list(self.decoded_results.keys())
        if store is not None: store = f'{store}_new'  # new format is not backwards compatible, so use new folder
        iterator = base_keys if quiet else tqdm(base_keys, desc='calculate augmented scores', file=sys.stdout)
        for bk in iterator:
            res = self.decoded_results.get(bk, {})
            known_scores = {}
            for k, v in sorted(res.items()):
                if 'output' in v:
                    k_store = None if store is None else os.path.join(store, k)
                    # NOTE: `id` shadows the builtin; kept as-is (code unchanged).
                    id = tuple(map(tuple, v['output']))
                    if not (make_unique and id in known_scores):
                        try:
                            # Cache hit: load previously computed scores.
                            assert k_store is not None
                            with bz2.BZ2File(k_store) as f: known_scores[id] = pickle.load(f)
                            if isinstance(known_scores[id], list): known_scores[id] = dict(score_multi=known_scores[id])  # for backwards compatibility
                            k_store = None
                        except:
                            # Cache miss (or unreadable): rebuild a one-task
                            # dataset around this output and re-score it under
                            # augmentation.
                            temp_dataset = self.dataset.__class__(
                                keys=[bk],
                                queries={bk: self.dataset.queries.get(bk)},
                                replies={bk: [v['output'].tolist()]},
                            )
                            temp_decoder = self.__class__(self.formatter, temp_dataset, n_guesses=self.n_guesses, quiet=True)
                            temp_dataset = temp_dataset.augment(**kwargs, seed=(seed+hash(k)+hash(id)) % 1024**2, quiet=True)
                            if max_len is not None: temp_dataset = temp_dataset.cut_to_len(formatter=self.formatter, name='input', max_len=max_len, quiet=True)
                            for x in temp_dataset.as_list(self.formatter): calc_score(**x, formatter=self.formatter, model=model, decoder=temp_decoder)
                            known_scores[id] = dict(
                                score_multi=[np.sum(x['score']) for x in temp_decoder.decoded_results[bk].values()],
                                score_multi_nl=[x['score_val'] for x in temp_decoder.decoded_results[bk].values()],
                                score_multi_array=np.array([x['score'] for x in temp_decoder.decoded_results[bk].values()]),
                                score_multi_array_cum=np.array([x['score_cum'] for x in temp_decoder.decoded_results[bk].values()]),
                                score_multi_array_all=np.array([x['score_all'] for x in temp_decoder.decoded_results[bk].values()]),
                                score_multi_array_all_cum=np.array([x['score_all_cum'] for x in temp_decoder.decoded_results[bk].values()]),
                            )
                        # Persist freshly computed scores (k_store is None on cache hit).
                        if k_store is not None:
                            os.makedirs(store, exist_ok=True)
                            with bz2.BZ2File(k_store, 'w') as f: pickle.dump(known_scores[id], f)
                    v.update(known_scores[id])
def turbo_dfs(model, logits, path, eos_token_id, max_new_tokens, max_score, max_score_greedy, temperature, suppress_tokens, score=0.0, pos=0, cache=None):
    """Depth-first search for generating sequences in a compute-efficient manner.

    Recursively expands every next-token candidate whose cumulative negative
    log-likelihood stays below a score bound, following a precomputed `path`
    (a previously found best sequence) first so the KV cache is reused along it.

    Args:
        model: Callable language model; called with input_ids/position_ids
            (and past_key_values when a cache is available).
        logits: Logits for the current position. NOTE(review): the code below
            indexes `logits.at[0, suppress_tokens]` (2-D) but then treats `nll`
            as 1-D over the vocab (`nll[i]`, `nll.shape[0]`); the initial call
            from `inference_turbo_dfs` passes a (seq, vocab) slice — these
            shapes look inconsistent; confirm against the working reference.
        path: List of token ids to explore first; consumed one token per level.
        eos_token_id: Token id that terminates a suffix.
        max_new_tokens: Remaining generation budget at this depth.
        max_score: NLL bound for non-greedy branches.
        max_score_greedy: Looser NLL bound applied to the greedy branch.
        temperature: Softmax temperature (> 0 rescales logits).
        suppress_tokens: Token ids forced to -1e10 before the softmax.
        score: Cumulative NLL of the prefix leading to this call.
        pos: Absolute position id of the token being generated.
        cache: Optional KV cache (past_key_values) to extend.

    Returns:
        List of (score, tokens, logits) tuples; `tokens` and `logits` are
        built back-to-front (appended on unwind), so callers must reverse them.
    """
    # Forbid suppressed tokens by pushing their logits to effectively -inf.
    logits = logits.at[0, suppress_tokens].set(-1e10)  # Suppress tokens
    # Temperature scaling, then negative log-likelihood per candidate token.
    if temperature > 0:
        logits = logits / temperature
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -np.array(log_probs)
    # The greedy token gets the looser `max_score_greedy` bound below.
    # NOTE(review): argmin over the whole array — only correct if `nll` is 1-D.
    greedy_index = int(np.argmin(nll))
    # Candidate list in plain index order.
    # NOTE(review): candidates are NOT sorted by likelihood, so this explores
    # vocab-id order; verify this matches the intended search order.
    nll_list = [(i, float(nll[i])) for i in range(nll.shape[0])]
    # Follow the precomputed path first: swap the path token to the front so
    # the known-good continuation is expanded before any alternatives.
    if path and len(path) > 0:
        first_token = path[0]
        # Swap first token with the path token
        for i, (idx, _) in enumerate(nll_list):
            if idx == first_token:
                nll_list[0], nll_list[i] = nll_list[i], nll_list[0]
                path = path[1:]
                break
    suffixes = []
    for i, s in nll_list:
        next_score = score + s
        # The greedy continuation may exceed `max_score` up to `max_score_greedy`.
        allowed_max_score = max_score_greedy if i == greedy_index else max_score
        if next_score < allowed_max_score:
            if i == eos_token_id:
                # Sequence complete: empty suffix; tokens/logits filled on unwind.
                next_suffixes = [(next_score, [], [])]
            elif max_new_tokens > 1:
                # Get next token logits for the single chosen token.
                next_input_ids = jnp.array([[i]])
                next_position_ids = jnp.array([[pos]])
                # Use cache for efficient generation
                if cache is not None:
                    # NOTE(review): assumes the model returns a (logits, cache)
                    # tuple here, but an object with `.logits` in the no-cache
                    # branch below — confirm the model's calling convention.
                    next_logits, next_cache = model(
                        input_ids=next_input_ids,
                        position_ids=next_position_ids,
                        past_key_values=cache
                    )
                    next_suffixes = turbo_dfs(
                        model, next_logits[0], path, eos_token_id, max_new_tokens-1,
                        max_score, allowed_max_score, temperature, suppress_tokens,
                        score=next_score, pos=pos+1, cache=next_cache
                    )
                else:
                    # No cache available, use standard forward pass
                    next_logits = model(
                        input_ids=next_input_ids,
                        position_ids=next_position_ids,
                    ).logits[0]
                    next_suffixes = turbo_dfs(
                        model, next_logits, path, eos_token_id, max_new_tokens-1,
                        max_score, allowed_max_score, temperature, suppress_tokens,
                        score=next_score, pos=pos+1, cache=None
                    )
            else:
                # Generation budget exhausted without EOS: drop this branch.
                next_suffixes = []
            # Prepend (via reversed append) the current token to every suffix
            # found below this node.
            for suffix in next_suffixes:
                suffix[1].append(i)
                suffix[2].append(logits)
            suffixes.extend(next_suffixes)
    return suffixes
def inference_turbo_dfs(model, input_ids, eos_token_id, max_new_tokens, min_prob, min_prob_greedy=1, temperature=1.0, suppress_tokens=None, path=None, attention_mask=None):
    """Perform inference using the turbo DFS approach.

    Primes the model with one forward pass over the prompt (plus an optional
    previously-found best `path`), then runs `turbo_dfs` to enumerate all
    continuations whose probability stays above `min_prob`.

    Args:
        model: Callable language model returning an object with `.logits`
            (and optionally `.past_key_values`).
        input_ids: 1-D token-id array/sequence (a leading batch dim of 1 is
            squeezed away); batching is not supported.
        eos_token_id: Token id that terminates a generated sequence.
        max_new_tokens: Maximum number of tokens to generate.
        min_prob: Minimum cumulative sequence probability to keep a branch.
        min_prob_greedy: Probability bound for the greedy branch; 1 (default)
            means the greedy branch gets no extra slack.
        temperature: Softmax temperature forwarded to the search.
        suppress_tokens: Optional iterable of token ids to forbid; defaults to
            no suppression. (Was a mutable `[]` default — fixed.)
        path: Optional best-guess token sequence to explore first.
        attention_mask: Unused; accepted for call-site compatibility.

    Returns:
        List of (score, tokens, logits) tuples sorted by ascending cumulative
        negative log-likelihood (lower score = more likely).
    """
    # Avoid the shared-mutable-default pitfall: fresh list per call.
    suppress_tokens = [] if suppress_tokens is None else list(suppress_tokens)
    # Convert inputs to JAX arrays if needed.
    if not isinstance(input_ids, jnp.ndarray):
        input_ids = jnp.array(input_ids, dtype=jnp.int32)
    # Remove batch dimension if present.
    if input_ids.ndim == 2:
        input_ids = input_ids[0]
    assert input_ids.ndim == 1, 'batching not supported'
    # Convert probability thresholds to negative-log-likelihood score bounds.
    max_score = -np.log(min_prob)
    max_score_greedy = -np.log(min_prob_greedy) if min_prob_greedy > 0 else float('inf')
    # The greedy bound must never be tighter than the general bound.
    max_score_greedy = max(max_score, max_score_greedy)
    # Drop a trailing EOS from the path so the search re-discovers it itself.
    if path is None:
        path = []
    if len(path) > 0 and path[-1] == eos_token_id:
        path = path[:-1]
    # Prepend the prompt so a single forward pass covers prompt + path.
    if len(path) > 0:
        full_path = jnp.concatenate([input_ids, jnp.array(path, dtype=jnp.int32)])
    else:
        full_path = input_ids
    # Initial forward pass: logits from the last prompt position onward, plus
    # the KV cache when the model exposes one.
    full_path_batched = full_path[None, :]  # Add batch dimension
    model_outputs = model(input_ids=full_path_batched)
    logits = model_outputs.logits[0, len(input_ids)-1:]
    cache = model_outputs.past_key_values if hasattr(model_outputs, 'past_key_values') else None
    # Run turbo DFS.
    result = turbo_dfs(
        model, logits, path, eos_token_id, max_new_tokens,
        max_score, max_score_greedy, temperature, suppress_tokens,
        score=0.0, pos=len(input_ids), cache=cache
    )
    # Suffixes are built back-to-front during DFS; reverse, then sort by score.
    sorted_result = sorted(
        [(score_val, np.array(suffix[::-1]), np.array(score_arr[::-1]))
         for score_val, suffix, score_arr in result],
        key=lambda x: x[0]
    )
    return sorted_result
def inference_step(tokenized, model, remove_token_type_ids=True, num_beams=1, formatter=None, min_prob=None, current_best=None, **kwargs):
    """Perform a single inference step.

    Dispatches between turbo-DFS decoding (when `min_prob` is given) and the
    model's standard `generate` path, returning numpy token ids and scores
    with an extra leading grouping dimension in both cases.

    Args:
        tokenized: Tokenizer output dict; mutated in place if
            `remove_token_type_ids` is set.
        model: JAX language model exposing `generate` and `config.vocab_size`.
        remove_token_type_ids: Drop 'token_type_ids' before converting inputs.
        num_beams: Beam count for standard generation (must be 1 for DFS).
        formatter: Provides the tokenizer (eos/pad token ids).
        min_prob: If not None, use turbo DFS with this probability floor.
        current_best: Optional token path handed to the DFS as a starting guess.
        **kwargs: Forwarded generation/search options.

    Returns:
        (tokens_out, scores_out) pair.
    """
    if remove_token_type_ids:
        tokenized.pop('token_type_ids', None)
    # Convert every tokenizer field to a JAX array.
    jax_inputs = {name: jnp.array(value) for name, value in tokenized.items()}
    if min_prob is None:
        # Standard generation path.
        prompt_len = jax_inputs['input_ids'].shape[-1]
        gen_config = dict(
            max_new_tokens=kwargs.get('max_new_tokens', 20),
            do_sample=kwargs.get('do_sample', False),
            temperature=kwargs.get('temperature', 1.0),
            top_p=kwargs.get('top_p', 1.0),
            top_k=kwargs.get('top_k', 50),
            num_beams=num_beams,
            eos_token_id=formatter.tokenizer.eos_token_id,
            pad_token_id=formatter.tokenizer.pad_token_id,
        )
        outputs = model.generate(
            **jax_inputs,
            **gen_config,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Strip the prompt, keep only freshly generated tokens.
        tokens_out = outputs.sequences[:, prompt_len:][None, :, :]
        if getattr(outputs, 'scores', None):
            scores_out = jnp.stack(outputs.scores, axis=1)[None, :, :]
        else:
            # Placeholder scores when the model did not report any.
            scores_out = jnp.zeros((1, tokens_out.shape[1], tokens_out.shape[2], model.config.vocab_size))
        # Hand back numpy arrays.
        tokens_out = np.array(tokens_out)
        scores_out = np.array(scores_out)
    else:
        # Turbo-DFS path; beam search is not compatible with it.
        assert num_beams == 1
        gen = inference_turbo_dfs(
            model,
            jax_inputs['input_ids'],
            path=current_best,
            min_prob=min_prob,
            eos_token_id=formatter.tokenizer.eos_token_id,
            **kwargs
        )
        tokens_out = [[candidate[1] for candidate in gen]]
        scores_out = [[candidate[2] for candidate in gen]]
    return tokens_out, scores_out
def process_inference_output(key, outputs, formatter, store=None, decoder=None, decoder_args=None):
    """Process inference outputs, detokenize, and save results.

    Args:
        key: Dataset key the outputs belong to.
        outputs: Tuple of parallel sequences (e.g. tokens, scores) that are
            zipped together and passed to `formatter.de_tokenize`.
        formatter: Object providing `de_tokenize`.
        store: Optional storage location forwarded to `inference_save`.
        decoder: Optional decoder whose `process` is invoked on the result.
        decoder_args: Optional extra keyword args for `decoder.process`.
            (Was a mutable `{}` default — fixed with a None sentinel.)

    Returns:
        The list of de-tokenized outputs.
    """
    if decoder_args is None:
        decoder_args = {}
    de_tokenized = [formatter.de_tokenize(*output) for output in zip(*outputs)]
    inference_save(store, key, de_tokenized)
    if decoder is not None:
        decoder.process(key, de_tokenized, **decoder_args)
    return de_tokenized
def inference_run_v2(model, formatter, dataset, decoder=None, max_new_tokens=None, max_batch_size=1, store=None, result_dict=None, rerun_empty=False, retrain=None, use_turbo=False, group_multi_output=True, **kwargs):
    """Run inference on a dataset with various optimization options.

    Loads previously stored results, determines which keys still need a run,
    optionally retrains the model per base key (test-time training), and
    feeds each remaining key through `inference_step`. Interruptible with
    Ctrl+C, in which case partial results are returned.

    Args:
        model: Language model passed through to retraining and inference.
        formatter: Provides the tokenizer, reply formatting, and max token count.
        dataset: Source of keys/inputs; also used to build retraining subsets.
        decoder: Optional decoder fed with every (stored or fresh) result.
        max_new_tokens: Generation budget; defaults to `formatter.max_new_tokens()`.
        max_batch_size: Must be 1 (batching not implemented for JAX here).
        store: Optional on-disk store for loading/saving results.
        result_dict: Optional dict to resume into; created when None.
        rerun_empty: Re-run keys whose stored result is empty.
        retrain: Optional callable `(model, dataset)` run once per base key.
        use_turbo: Seed turbo-DFS decoding with the decoder's current best guess.
        group_multi_output: Also strip the '_' suffix when grouping keys.
        **kwargs: Forwarded to `inference_step`.

    Returns:
        Mapping from dataset key to its de-tokenized inference output.
    """
    assert max_batch_size == 1, 'batch size > 1 not supported in JAX implementation yet'
    logger.info('*** Load stored data...')
    if result_dict is None:
        result_dict = {}
    result_dict = inference_load(store, dataset.keys, result_dict)
    # Group keys by base key, preserving first-seen order in base_key_list.
    by_base_key = {}
    needs_rerun = {}
    base_key_list = []
    for key in dataset.keys:
        # Base key is the part before the first '.'; optionally also strip
        # the '_' suffix so multi-output variants share one group.
        base_key = key.split('.')[0]
        if group_multi_output:
            base_key = base_key.split('_')[0]
        if base_key not in by_base_key:
            base_key_list.append(base_key)
        bk_list = by_base_key[base_key] = by_base_key.get(base_key, [])
        bk_list.append(key)
    # Decide per key: rerun (missing or, if requested, empty result) vs.
    # replay the stored result straight into the decoder.
    for base_key, keys in by_base_key.items():
        for key in keys:
            de_tokenized = result_dict.get(key)
            if de_tokenized is None or (rerun_empty and not de_tokenized):
                bk_list = needs_rerun[base_key] = needs_rerun.get(base_key, [])
                bk_list.append(key)
            elif decoder is not None:
                decoder.process(key, de_tokenized)
    # Left padding so generation appends at the sequence end.
    formatter.tokenizer.padding_side = 'left'
    if max_new_tokens is None:
        max_new_tokens = formatter.max_new_tokens()
    # Model should be in evaluation mode.
    model_forward = model
    logger.info('*** Start inference run...')
    try:
        for base_key in tqdm(base_key_list, file=sys.stdout):
            run_keys = needs_rerun.get(base_key)
            if run_keys:
                # Optional test-time training on this base key's subset,
                # done once before running all of its keys.
                if retrain is not None:
                    retrain_dataset = dataset.keep_key_startswith(base_key)
                    logger.info(f"retraining model for key '{base_key}' (retrain_dataset_size={len(retrain_dataset.keys)})")
                    retrain(model, retrain_dataset)
                for key in run_keys:
                    input_text = dataset.get(key, formatter)['input']
                    batch = formatter.tokenizer([input_text], return_tensors='jax')
                    # For turbo decoding, seed the search with the decoder's
                    # current best guess, re-encoded relative to this key.
                    current_best = decoder.get_current_best(key.split('.')[0]) if use_turbo else None
                    if current_best is not None:
                        current_best = dataset.forward_mod(current_best, key)
                        current_best = formatter.fmt_reply([current_best])
                        # Keep only the reply's token ids (drop the prompt prefix).
                        current_best = formatter.tokenizer(input_text+current_best)['input_ids'][batch['input_ids'].shape[-1]:]
                    batch_out = inference_step(
                        batch, model_forward, formatter=formatter,
                        max_new_tokens=max_new_tokens, current_best=current_best, **kwargs
                    )
                    # Unwrap the single-batch grouping dimension.
                    outputs = [x[0] for x in batch_out]
                    result_dict[key] = process_inference_output(
                        key, outputs, formatter, store=store, decoder=decoder,
                        decoder_args=dict(print_func=lambda x: logger.info(x))
                    )
        logger.info('*** Completed inference run.')
    except KeyboardInterrupt:
        # Graceful stop: partial results collected so far are still returned.
        logger.info('*** Ctrl+C pressed, stopping inference run.')
    return result_dict
class Retrainer(object):
    """Test-time retraining helper.

    Builds an augmented training set of a fixed size from a per-task dataset
    and runs a training pass, optionally restoring PEFT weights first.
    """

    def __init__(self, n, aug_opts, reload_state_dict=None, **kwargs):
        # n: exact number of augmented training examples to produce.
        self.n = n
        # aug_opts: keyword options forwarded to dataset.augment().
        self.aug_opts = aug_opts
        # reload_state_dict: PEFT weights restored before each training run.
        self.reload_state_dict = reload_state_dict
        # kwargs: forwarded verbatim to training_run().
        self.kwargs = kwargs

    def preprocess(self, dataset):
        """Build an augmented training set of exactly ``self.n`` examples."""
        # Number of augmentation rounds needed to reach at least n examples.
        rounds = (self.n - 1) // dataset.length() + 1
        augmented = [dataset.augment(quiet=True, shfl_keys=True, **self.aug_opts) for _ in range(rounds)]
        if len(augmented) == 1:
            merged = augmented[0]
        else:
            merged = augmented[0].append(*augmented[1:])
        # Truncate to exactly n examples.
        trimmed, _ = merged.split_at_pos(self.n)
        return trimmed

    def __call__(self, model, dataset):
        """Optionally restore base adapter weights, then train on the dataset."""
        if self.reload_state_dict is not None:
            set_peft_weights(model, self.reload_state_dict)
        # Train the model on the freshly augmented data.
        training_run(model, dataset=self.preprocess(dataset), **self.kwargs)
def calc_score(key, input, reply, formatter, model, store=None, decoder=None, **_):
    """Calculate score for a given input-reply pair.

    Tokenizes prompt+reply, extracts the logits that predict each reply token,
    and hands tokens plus logits to `process_inference_output` for scoring
    and optional storage/decoding.

    Args:
        key: Dataset key for this example.
        input: Prompt text (name kept for keyword-call compatibility).
        reply: Candidate reply text to score.
        formatter: Provides the tokenizer and de-tokenization.
        model: Language model returning `.logits`.
        store: Optional storage location forwarded downstream.
        decoder: Optional decoder forwarded downstream.
        **_: Ignored extra keyword arguments.
    """
    tok = formatter.tokenizer
    # Length of the prompt alone, in tokens.
    prompt_len = len(tok(input)['input_ids'])
    # Tokenize prompt + reply together so positions line up.
    encoded = tok([input + reply], return_tensors='jax')
    # The reply's token ids are everything after the prompt.
    reply_tokens = np.array(encoded['input_ids'][0][prompt_len:])
    # Logits at positions prompt_len-1 .. -2 predict exactly the reply tokens.
    model_out = model(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask']
    )
    reply_logits = np.array(model_out.logits[0, prompt_len - 1:-1])
    # Score and record via the shared output pipeline.
    process_inference_output(
        key,
        (reply_tokens[np.newaxis], reply_logits[np.newaxis]),
        formatter,
        store=store,
        decoder=decoder
    )
def mem_info(device_id=0):
    """Log memory usage information for one JAX device.

    Args:
        device_id: Index into ``jax.devices()`` of the device to report.

    Uses the supported per-device ``Device.memory_stats()`` API (the previous
    code called a non-existent ``jax.device_memory_profile``; the real
    profiler API, ``jax.profiler.device_memory_profile``, returns pprof bytes,
    not a dict). Logs used/total memory in GB when the backend tracks memory;
    logs a fallback message otherwise. Never raises.
    """
    try:
        devices = jax.devices()
        if device_id < len(devices):
            device = devices[device_id]
            # memory_stats() may return None on backends (e.g. CPU) that do
            # not track memory; treat that as "no stats available".
            stats = device.memory_stats() or {}
            used_memory = stats.get('bytes_in_use', 0) / 1024**3
            total_memory = stats.get('bytes_limit', 0) / 1024**3
            logger.info(f"*** Device: {device}, used {used_memory:.3f} / {total_memory:.3f} GB.")
        else:
            logger.info(f"*** Device ID {device_id} not found among available devices.")
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; memory reporting stays best-effort.
        logger.info('*** Exception occurred when getting memory stats.')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment