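# Generation script: samples prompts from the allenai/tulu-3-sft-mixture dataset,
# elicits <thoughts> from a local OpenAI-compatible completion server, seeds up to
# 62 candidate thoughts per step (one per ALPHANUMERIC prefix) for DEPTH steps,
# ranks the resulting answers with rank_cot_dataset.rank_question, and writes the
# best thought chains to outputs.jsonl / best_outputs.jsonl / training_data.jsonl.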
import copy
import json
import os
import logging
import random
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
import fire
from openai import OpenAI
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset
from tqdm import tqdm
from rank_cot_dataset import rank_question
BASE_URL = "http://127.0.0.1:8086/v1"
API_KEY = "asdf"
ELICIT_PROMPT = "\n\nNOTE: Before you respond, use...\n<thoughts>\n<thought>...</thought>\n<thought>...</thought>\n...\n</thoughts>\n to jot down various thoughts, facts, ideas, plans, algorithms, or anything you think will help yourself answer better."
ELICIT_LEVEL = 1  # 0 = no elicit prompt, 1 = elicit prompt + 3 sampled few-shot convos, 2 = elicit prompt + as many shuffled convos as fit
OUTPUT_FILE = "outputs.jsonl"
BEST_OUTPUT_FILE = "best_outputs.jsonl"
TRAINING_FILE = "training_data.jsonl"
TRAINING_FILE0 = "training_data0.jsonl"
TOKENIZER_NAME = "meta-llama/Llama-3.2-3B-Instruct"
RETRY_ATTEMPTS = 3
NUM_WORKERS = 96
DEPTH = 6
DATASET_SEED = 58
ALPHANUMERIC = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
CAP_ALNUM_LIST = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
USE_TRAINING_DATA_AS_HISTORY = True
# RANK_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
# RANK_MODEL = "gemini-1.5-flash-8b"
RANK_API_KEY = "asdf"
RANK_BASE_URL = "http://127.0.0.1:8087/v1"
RANK_MODEL = "allenai/Llama-3.1-Tulu-3-8B"
STOP_WORDS = [
    "</thought>",
    "</thoughts>",
    "<thought",
    "<|im_end|>",
    "<|eot_id|>",
    "<|end_header_id|>",
    "<|start_header_id|>",
    "<|end",
    "<|start",
]
VERBOSE = True
# logging.basicConfig(level=logging.INFO)
@dataclass
class Message:
    content: str
    elicit_prompt: str = ELICIT_PROMPT

    @property
    def full_content(self) -> str:
        if ELICIT_LEVEL > 0:
            return f"{self.content}{self.elicit_prompt}"
        return self.content
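# Example: with ELICIT_LEVEL > 0, Message("What is 2+2?").full_content is the user
# question with ELICIT_PROMPT appended, nudging the model to emit <thoughts> first.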
def write_jsonl(data, filepath):
    """Write a single result to jsonl file and flush"""
    try:
        with open(filepath, "a") as f:
            json.dump(data, f)
            f.write("\n")
            f.flush()
        if VERBOSE and filepath == TRAINING_FILE:
            print(f"Successfully wrote training data to {filepath}")
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        logging.error(f"Error writing to {filepath}: {str(e)}")
        traceback.print_exc()
def load_progress():
    """Load previously processed prompts and results"""
    completed_prompts = set()
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            for line in f:
                try:
                    data = json.loads(line.strip())
                    completed_prompts.add(data["prompt"])
                except json.JSONDecodeError:
                    logging.warning("Skipping corrupted line in output file")
                    continue
    return completed_prompts
def safe_api_call(func, *args, max_retries=RETRY_ATTEMPTS, **kwargs):
    """Make API calls with retry logic"""
    for attempt in range(max_retries):
        try:
            return func(*args, **kwargs)
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception as e:
            logging.warning(f"API call failed (attempt {attempt + 1}): {str(e)}")
            if attempt == max_retries - 1:
                return None
            continue
def get_thought_response(
    client, client_model, prompt, thoughts, suffix, thought_chunk_size
):
    """Get thought response with error handling"""
    try:
        full_prompt = f"{prompt}<thoughts>\n"
        for thought in thoughts:
            full_prompt += f"<thought>{thought}</thought>\n"
        full_prompt += f"<thought>{suffix}"
        output = safe_api_call(
            client.completions.create,
            model=client_model.id,
            prompt=full_prompt,
            max_tokens=thought_chunk_size,
            temperature=0,
            stop=STOP_WORDS + ["\n"],
            frequency_penalty=1.05,
        )
        if output:
            return suffix + output.choices[0].text.strip()
        return None
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        logging.error(f"Error in get_thought_response: {str(e)}")
        return None
def get_response(client, client_model, prompt, thoughts, max_tokens):
    """Get response with error handling"""
    try:
        full_prompt = f"{prompt}<thoughts>\n"
        for thought in thoughts:
            full_prompt += f"<thought>{thought}</thought>\n"
        full_prompt += "</thoughts>\n\n"
        output = safe_api_call(
            client.completions.create,
            model=client_model.id,
            prompt=full_prompt,
            max_tokens=max_tokens,
            temperature=0,
            stop=STOP_WORDS,
        )
        if output:
            return output.choices[0].text.strip()
        return None
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        logging.error(f"Error in get_response: {str(e)}")
        return None
def setup_clients():
    client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
    client_model = client.models.list().data[0]
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    return client, client_model, tokenizer
def prepare_dataset(sample_size, convo_count, random_seed):
    dataset = load_dataset("allenai/tulu-3-sft-mixture")["train"]
    history_convos = list(
        map(
            dict,
            dataset.shuffle(seed=random_seed).select(range(sample_size + convo_count)),
        )
    )
    return history_convos
def extract_prompt_messages(history_convos, convo_count):
    prompt_messages = []
    for _ in range(convo_count):
        msg = history_convos.pop(0)["messages"][0]
        prompt_messages.append(Message(content=msg["content"]))
    return prompt_messages
def prepare_messages(
    message, history_convos, tokenizer, max_tokens, output_tokens, reserved_tokens
):
    msgs = []
    msg = {"role": "user", "content": message.full_content}
    convs = []
    if ELICIT_LEVEL > 0:
        if ELICIT_LEVEL == 2:
            convs = history_convos
            random.shuffle(convs)
        elif ELICIT_LEVEL == 1:
            convs = random.sample(history_convos, 3)
        else:
            raise ValueError(f"Invalid ELICIT_LEVEL: {ELICIT_LEVEL}")
    else:
        convs = []
    for convo in convs:
        potential_prompt = tokenizer.apply_chat_template(
            copy.deepcopy(msgs) + convo["messages"] + [msg],
            tokenize=True,
            add_generation_prompt=True,
        )
        if len(potential_prompt) > max_tokens - output_tokens - reserved_tokens:
            break
        msgs.extend(convo["messages"])
    msgs.append(msg)
    return msgs
def process_single_message(args):
    (
        message,
        history_convos,
        client,
        client_model,
        tokenizer,
        max_tokens,
        depth,
        output_tokens,
        reserved_tokens,
        thought_chunk_size,
    ) = args
    # Initialize these before the try block so the exception handler below can
    # always reference them, even if an error occurs before generation starts.
    rank_client = OpenAI(api_key=RANK_API_KEY, base_url=RANK_BASE_URL)
    thought_stack = []
    answer_stack = []
    all_results = []
    best_results = []
    try:
        msgs = prepare_messages(
            message,
            history_convos,
            tokenizer,
            max_tokens,
            output_tokens,
            reserved_tokens,
        )
        base_prompt = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
        for _ith in range(depth):
            # if VERBOSE:
            #     thought_stack_fmt = " ".join(f"({t[:10]})" for t in thought_stack)
            # Develop thoughts
            new_thoughts = []
            for char in ALPHANUMERIC:
                new_thought = get_thought_response(
                    client,
                    client_model,
                    base_prompt,
                    thought_stack,
                    char,
                    thought_chunk_size,
                )
                if new_thought:
                    # if VERBOSE:
                    #     print(
                    #         f"THOUGHT: {{{len(thought_stack)}}} [Q:{json.dumps(message.content[-30:])}] {thought_stack_fmt} {new_thought}"
                    #     )
                    new_thoughts.append(new_thought)
            # Generate results for each thought
            results = []
            seen_mini_ids = set()
            for thought in new_thoughts:
                if len(thought) <= 2:
                    continue
                response = get_response(
                    client,
                    client_model,
                    base_prompt,
                    thought_stack + [thought],
                    output_tokens,
                )
                if not response:
                    continue
                if len(response) > 0:
                    # if VERBOSE:
                    #     print(
                    #         f"RESPONSE: {{{len(thought_stack)}}} [Q:{json.dumps(message.content[-30:])}] {thought_stack_fmt} ({thought[:10]}) {json.dumps(response)[-100:]}"
                    #     )
                    mini_id = ""
                    while mini_id == "" or mini_id in seen_mini_ids:
                        mini_id = f"{random.choice(CAP_ALNUM_LIST)}{random.choice(CAP_ALNUM_LIST)}"
                    seen_mini_ids.add(mini_id)
                    results.append(
                        {
                            "mini_id": mini_id,
                            "prompt": message.content,
                            "base_prompt": base_prompt,
                            "thoughts": copy.deepcopy(thought_stack) + [thought],
                            "answer": response,
                        }
                    )
            if not results:
                ranked_best_results = rank_question(
                    client=rank_client, rows=best_results, model=RANK_MODEL
                )
                for rank, result in enumerate(ranked_best_results):
                    result["rank"] = rank
                return all_results, ranked_best_results
            # Rank the thoughts by their results
            ranked_results = rank_question(
                client=rank_client, rows=results, model=RANK_MODEL
            )
            if not ranked_results:
                print(f"WTF THIS SHOULD NOT HAPPEN {len(results)}")
                ranked_best_results = rank_question(
                    client=rank_client, rows=best_results, model=RANK_MODEL
                )
                for rank, result in enumerate(ranked_best_results):
                    result["rank"] = rank
                return all_results, ranked_best_results
            for rank, result in enumerate(ranked_results):
                result["rank"] = rank
            thought_stack.append(ranked_results[0]["thoughts"][-1])
            answer_stack.append(ranked_results[0]["answer"])
            all_results.extend(ranked_results)
            best_results.insert(0, ranked_results[0])
        ranked_best_results = rank_question(
            client=rank_client, rows=best_results, model=RANK_MODEL
        )
        for rank, result in enumerate(ranked_best_results):
            result["rank"] = rank
        return all_results, ranked_best_results
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        # logging.error(f"Error processing prompt {message.content}: {str(e)}")
        exc = "\n".join(traceback.format_exception(e))
        logging.error(f"Error processing prompt: {exc}")
        ranked_best_results = rank_question(
            client=rank_client, rows=best_results, model=RANK_MODEL
        )
        for rank, result in enumerate(ranked_best_results):
            result["rank"] = rank
        return all_results, ranked_best_results
def process_messages_parallel(prompt_messages, history_convos, config):
    client, client_model, tokenizer = setup_clients()
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        args_list = [
            (
                msg,
                history_convos,
                client,
                client_model,
                tokenizer,
                config["max_tokens"],
                config["depth"],
                config["output_tokens"],
                config["reserved_tokens"],
                config["thought_chunk_size"],
            )
            for msg in prompt_messages
        ]
        futures = [executor.submit(process_single_message, args) for args in args_list]
        all_results = []
        # best_results = []
        for future in tqdm(
            as_completed(futures), total=len(futures), desc="Processing prompts"
        ):
            try:
                results, best_batch_results = future.result()
                if VERBOSE:
                    print(f"Got {len(best_batch_results)} best results from batch")
                all_results.extend(results)
                # Write results as they complete
                for result in results:
                    write_jsonl(result, OUTPUT_FILE)
                for result in best_batch_results:
                    write_jsonl(result, BEST_OUTPUT_FILE)
                if best_batch_results:
                    if VERBOSE:
                        print(
                            f"Processing {len(best_batch_results)} best results for training data"
                        )
                    for best in best_batch_results:
                        try:
                            thoughts_fmt = "<thoughts>\n"
                            for thought in best["thoughts"]:
                                thoughts_fmt += f"<thought>{thought}</thought>\n"
                            thoughts_fmt += f"</thoughts>\n\n{best['answer']}"
                            train_row = {
                                "messages": [
                                    {"role": "user", "content": best["prompt"]},
                                    {"role": "assistant", "content": thoughts_fmt},
                                ]
                            }
                            write_jsonl(train_row, TRAINING_FILE)
                        except (SystemExit, KeyboardInterrupt):
                            raise
                        except Exception as e:
                            logging.error(
                                f"Error processing best result for training: {str(e)}"
                            )
                            traceback.print_exc()
                        # Break on purpose, we only want the first (top-ranked) row
                        break
            except (SystemExit, KeyboardInterrupt):
                raise
            except Exception as e:
                logging.error(f"Error processing future result: {str(e)}")
                traceback.print_exc()
    return all_results
def main(
    sample_size: int = 100,
    random_seed: int = DATASET_SEED,
    output_tokens: int = 4096,
    reserved_tokens: int = 1024,
    max_tokens: int = 65536,
    thought_chunk_size: int = 128,
):
    config = {
        "sample_size": sample_size,
        "random_seed": random_seed,
        "output_tokens": output_tokens,
        "reserved_tokens": reserved_tokens,
        "max_tokens": max_tokens,
        "depth": DEPTH,
        "thought_chunk_size": thought_chunk_size,
    }
    completed_prompts = load_progress()
    history_convos = prepare_dataset(sample_size, 1000, random_seed)
    prompt_messages = extract_prompt_messages(history_convos, 1000)
    # Filter out already completed prompts
    prompt_messages = [
        msg for msg in prompt_messages if msg.content not in completed_prompts
    ]
    if USE_TRAINING_DATA_AS_HISTORY and os.path.exists(TRAINING_FILE0):
        with open(TRAINING_FILE0, "r") as f:
            old_history_convos = history_convos
            history_convos = [json.loads(line.strip()) for line in f if line.strip()]
        if VERBOSE:
            print(
                f"Replacing historical convos ({len(old_history_convos)}) with generated ones with thoughts ({len(history_convos)})"
            )
    results = process_messages_parallel(prompt_messages, history_convos, config)
    return results
if __name__ == "__main__":
    fire.Fire(main)
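# ---------------------------------------------------------------------------
# Second script in the gist (filename not shown): merges and expands the generated
# records (best_outputs.jsonl-style rows with "prompt", "thoughts", "answer") into
# a chat-format SFT file, merged_training_data.jsonl by default, optionally adding
# a "think before responding" system message to a fraction of examples.
# ---------------------------------------------------------------------------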
import json
import random
import fire
# from transformers import AutoTokenizer
def main(
    *filenames,
    out_filename: str = "merged_training_data.jsonl",
    expand_factor: int = 4,
    include_system: bool = True,
    include_system_pct: float = 0.4,
    system_thoughts_pct: float = 0.75,
    system_number_pct: float = 0.5,
    random_seed: int = 42,
):
    random.seed(random_seed)
    data = []
    seen_prompts = set()
    for filename in filenames:
        print(f"Opening {filename}")
        with open(filename) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                row = json.loads(line)
                if row["prompt"] in seen_prompts:
                    continue
                seen_prompts.add(row["prompt"])
                data.append(row)
    print(f"Found {len(data)} records")
    # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    records_written = 0
    with open(out_filename, "w") as f:
        for _ in range(expand_factor):
            random.shuffle(data)
            for row in data:
                response = "<thoughts>\n"
                for thought in row["thoughts"]:
                    response += f"<thought>{thought}</thought>\n"
                response += f"</thoughts>\n\n{row['answer']}"
                messages = []
                if include_system and random.random() < include_system_pct:
                    system_message = "You are a helpful assistant."
                    if random.random() < system_thoughts_pct:
                        if random.random() < system_number_pct:
                            system_message += f" Think {len(row['thoughts'])} thought{'' if len(row['thoughts']) == 1 else 's'} before responding."
                        else:
                            system_message += " Think before responding."
                    messages += [{"role": "system", "content": system_message}]
                messages += [
                    {"role": "user", "content": row["prompt"]},
                    {"role": "assistant", "content": response},
                ]
                # prompt = tokenizer.apply_chat_template(messages, tokenize=False)
                f.write(json.dumps({"messages": messages}) + "\n")
                records_written += 1
    print(f"Wrote {records_written} records to {out_filename}")
if __name__ == "__main__":
    fire.Fire(main)
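# ---------------------------------------------------------------------------
# Presumably rank_cot_dataset.py (the module imported by the generation script
# above): ranks candidate question/answer pairs with an LLM judge plus a
# TrueSkill/Elo-style Bayesian rating over small comparison groups.
# ---------------------------------------------------------------------------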
import json
import random
import math
from typing import Optional
from openai import OpenAI
from tqdm import tqdm
import time
# Constants
SYNCHRONOUS = False
API_KEY = "asdf"
BASE_URL = "http://localhost:8085/v1"
MODEL = "meta-llama/Llama-3.2-3B-Instruct"
SYSTEM_MESSAGE = """
You are an expert evaluator of question-answer pairs, tasked with analyzing and ranking responses based on their quality. You will assess each (id, question, answer) triplet and output a JSON array of IDs ranked from highest to lowest quality.
Primary Evaluation Criteria (70% weight):
1. Question-Answer Alignment & Relevance (25%)
- Complete addressing of all explicit and implicit requirements
- Direct relevance to the query intent
- Appropriate scope (neither over-answering nor under-answering)
- Recognition of context and constraints
- Anticipation of related follow-up questions
2. Information Quality & Accuracy (25%)
- Factual correctness and precision
- Logical coherence and flow
- Evidence and citations when needed
- Absence of hallucinations or fabrications
- Appropriate confidence level
- Handling of edge cases and assumptions
3. Communication Effectiveness (20%)
- Clear, well-structured presentation
- Appropriate technical depth for audience
- Efficient use of language
- Professional and engaging tone
- Accessibility and readability
- Effective use of formatting/organization
Secondary Criteria (30% weight):
4. Cognitive Value (10%)
- Critical thinking demonstration
- Novel perspectives and insights
- Connection-making ability
- Problem-solving approach
- Educational value
5. Practical Value (10%)
- Real-world applicability
- Actionable advice or steps
- Implementation feasibility
- Time relevance
- Resource considerations
6. Ethical & Social Aspects (10%)
- Cultural sensitivity and awareness
- Bias detection and fairness
- Ethical implications
- Source attribution
- Impact consideration
Automatic Disqualifiers:
- Harmful or unethical content
- Deliberately misleading information
- Complete topic misalignment
- Severe factual errors
- Internal contradictions
- Nonsensical or incoherent content
Special Considerations:
- Question type and purpose
- Audience expertise level
- Time sensitivity
- Cultural context
- Required evidence/citations
- Expected format requirements
Ranking Process:
1. Evaluate each response against weighted criteria
2. Apply context-specific requirements
3. Check for disqualifying issues
4. Consider special factors
5. Sort responses from highest to lowest quality
6. Verify all IDs included exactly once
Output Format:
{"ranking": ["ID1", "ID2", "ID3", ...]}
Your output must be clean JSON only, with no additional text or formatting. If ranking cannot be determined, output {"ranking": []}.
""".strip()
class Player:
    def __init__(self, id_str: str):
        self.id = id_str
        self.mu = 25.0  # Initial mean
        self.sigma = 8.33  # Initial standard deviation
        self.num_comparisons = 0

    @property
    def rating(self):
        return self.mu - 3 * self.sigma  # Conservative rating estimate
def _qa(row: dict) -> str:
    return f"""
<QuestionAnswerPair>
<id>{row['mini_id']}</id>
<Question>{row['prompt']}</Question>
<Answer>{row['answer']}</Answer>
</QuestionAnswerPair>
""".strip()
def _validate_ranking(ranking: list, rows: list[dict]) -> Optional[str]:
    if not isinstance(ranking, list):
        return f"Invalid ranking (not a list): {ranking}"
    if not all(isinstance(item, str) for item in ranking):
        return f"Invalid ranking (not a list of strings): {ranking}"
    if len(ranking) != len(rows):
        return f"Invalid ranking (wrong length): {ranking}"
    if len(set(ranking)) != len(ranking):
        return f"Invalid ranking (duplicates): {ranking}"
    if set(ranking) != set(row["mini_id"] for row in rows):
        return f"Invalid ranking (wrong ids): {ranking}"
    return None
def get_ranking(
    client: OpenAI,
    rows: list[dict],
    model: str = MODEL,
    temperature: float = 0.0,
    num_attempts: int = 2,
) -> Optional[list[str]]:
    """Ask the judge model for a ranking, retrying with validation feedback on failure"""
    if not rows:
        return rows
    user_content = "\n".join(_qa(row=row) for row in rows)
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {
            "role": "user",
            "content": f"<QuestionAnswerPairs>\n{user_content}\n</QuestionAnswerPairs>",
        },
    ]
    schema = {
        "type": "object",
        "properties": {
            "ranking": {
                "type": "array",
                "items": {"type": "string"},
                "description": "The ranking of the QAPairs, from best to worst.",
            },
        },
        "required": ["ranking"],
        "additionalProperties": False,
    }
    for _ in range(num_attempts):
        content = None
        try:
            response = client.chat.completions.create(
                model=model,
                temperature=temperature,
                messages=messages,
                response_format={
                    "type": "json_schema",
                    "json_schema": {
                        "name": "ranking",
                        "strict": True,
                        "schema": schema,
                    },
                },
            )
            content = response.choices[0].message.content.strip()
            ranking = json.loads(content).get("ranking", []) or []
            validation_error = _validate_ranking(ranking, rows)
            if not validation_error:
                return ranking
            messages.append({"role": "assistant", "content": content})
            messages.append(
                {
                    "role": "user",
                    "content": f"<ValidationError>\n{validation_error}\n</ValidationError>\n<MessageToRanker>\nPlease try again (to rank the given QA pairs based on the provided criteria) but correct the validation error\n</MessageToRanker>\n<RepeatedQuestion>\n<QuestionAnswerPairs>\n{user_content}\n</QuestionAnswerPairs>\n</RepeatedQuestion>",
                }
            )
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception as e:
            print(f"WARNING: {e}")
            if content:
                messages.append({"role": "assistant", "content": content})
                messages.append(
                    {
                        "role": "user",
                        "content": f"ERROR: {str(e)}",
                    }
                )
            else:
                raise ValueError(
                    f"Failed to get valid ranking ({num_attempts}) - INFERENCE_ERROR"
                )
    raise ValueError(f"Failed to get valid ranking after {num_attempts} attempts")
def update_ratings(winner: Player, loser: Player, k: float = 0.5):
    """Update ratings using Bayesian update with dynamic K-factor"""
    # Calculate dynamic K-factor based on number of comparisons
    k_factor = k / math.sqrt(max(1, min(winner.num_comparisons, loser.num_comparisons)))
    expected_win = 1.0 / (
        1.0
        + math.exp(
            (loser.mu - winner.mu) / math.sqrt(2.0 * (winner.sigma**2 + loser.sigma**2))
        )
    )
    # Update step sizes based on uncertainty
    winner_step = k_factor * winner.sigma
    loser_step = k_factor * loser.sigma
    # Update means
    winner.mu += winner_step * (1 - expected_win)
    loser.mu -= loser_step * (1 - expected_win)
    # Update standard deviations
    winner.sigma *= math.sqrt(
        max(0.95, 1.0 - k_factor * expected_win * (1 - expected_win))
    )
    loser.sigma *= math.sqrt(
        max(0.95, 1.0 - k_factor * expected_win * (1 - expected_win))
    )
    # Increment comparison counts
    winner.num_comparisons += 1
    loser.num_comparisons += 1
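# Minimal usage sketch (hypothetical IDs): a single win nudges the winner's mean up,
# the loser's mean down, and shrinks both uncertainties slightly.
#   a, b = Player("AA"), Player("BB")
#   update_ratings(winner=a, loser=b)
#   assert a.mu > 25.0 > b.mu and a.sigma < 8.33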
def rank_question(
    client: OpenAI,
    rows: list[dict],
    model: str = MODEL,
    group_size: int = 3,
    num_groups: Optional[int] = None,
) -> Optional[list[dict]]:
    """Main ranking function using Bayesian skill ratings with comparison tracking"""
    if not rows:
        return rows
    num_items = len(rows)
    # Calculate optimal number of groups
    if num_groups is None:
        try:
            # We need roughly n*log(n) comparisons total for a good ranking
            target_comparisons = int(num_items * math.log2(num_items))
            # Each group of size k generates k(k-1)/2 comparisons
            comparisons_per_group = (group_size * (group_size - 1)) / 2
            # Calculate groups needed to get close to target comparisons
            num_groups = max(
                2,  # Minimum of 2 groups
                min(
                    int(target_comparisons / comparisons_per_group),
                    2 * math.ceil(num_items / group_size),  # Cap at 2x coverage
                ),
            )
        except (ValueError, ZeroDivisionError):
            num_groups = max(2, math.ceil(num_items / max(1, group_size)))
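    # Worked example (illustrative numbers): with 36 rows and group_size=3,
    # target_comparisons = int(36 * log2(36)) = 186 and each group yields
    # 3 * 2 / 2 = 3 comparisons, so ~62 groups would hit the target; the
    # 2x-coverage cap (2 * ceil(36 / 3) = 24) is smaller, so num_groups = 24.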
    # Initialize players
    players = {row["mini_id"]: Player(row["mini_id"]) for row in rows}
    compared_pairs = set()  # Track all compared pairs
    # Generate and process groups
    for _ in tqdm(range(num_groups), desc="Ranking groups"):
        # Select items preferring those with fewer comparisons
        weights = [1.0 / (players[row["mini_id"]].num_comparisons + 1) for row in rows]
        group = random.choices(rows, weights=weights, k=min(group_size, len(rows)))
        # Drop duplicate picks: random.choices samples with replacement, and a group
        # with repeated mini_ids can never pass _validate_ranking.
        group = list({row["mini_id"]: row for row in group}.values())
        # Skip if all pairs in this group have been compared
        group_ids = tuple(sorted(row["mini_id"] for row in group))
        if all(
            (a, b) in compared_pairs
            for i, a in enumerate(group_ids)
            for b in group_ids[i + 1 :]
        ):
            continue
        try:
            ranking = get_ranking(client=client, rows=group, model=model)
        except ValueError as e:
            print(f"WARN: {e}")
            # If ranking fails, skip this group
            continue
        if ranking is None:
            continue
        # Update ratings based on ranking
        for i in range(len(ranking)):
            for j in range(i + 1, len(ranking)):
                winner_id, loser_id = ranking[i], ranking[j]
                if (winner_id, loser_id) in compared_pairs:
                    continue
                compared_pairs.add((winner_id, loser_id))
                try:
                    update_ratings(players[winner_id], players[loser_id])
                except (KeyboardInterrupt, SystemExit):
                    raise
                except Exception as e:
                    print(f"Error updating ratings: {e}")
                    continue
    # Sort rows based on final ratings
    sorted_ids = sorted(players.keys(), key=lambda x: players[x].rating, reverse=True)
    id_to_row = {row["mini_id"]: row for row in rows}
    ranked_rows = [id_to_row[id_str] for id_str in sorted_ids]
    return ranked_rows
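# ---------------------------------------------------------------------------
# Training script (train_multi.py, per the launch commands below): a TRL
# SFTTrainer run over the merged chain-of-thought training data.
# ---------------------------------------------------------------------------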
"""
python src/minicot/train.py \
python train_multi.py \
accelerate launch --num_processes 7 train_multi.py \
--model_name_or_path meta-llama/Llama-3.2-3B-Instruct \
--dataset_name merged_training_data.jsonl \
--learning_rate 1e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--logging_steps 1 \
--warmup_ratio 0.1 \
--eval_strategy steps \
--eval_steps 20 \
--torch_dtype bfloat16 \
--max_seq_length 16384 \
--use_liger \
--attn_implementation flash_attention_2 \
--optim adamw_bnb_8bit \
--output_dir Llama-3.2-3B-COTv2.3
"""
import random
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
ModelConfig,
ScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
setup_chat_format,
)
if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()
    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        use_fast=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path,
            **model_kwargs,
        )
        model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
    else:
        model = model_config.model_name_or_path
    ################
    # Dataset
    ################
    # Alternative (commented out): pack several shuffled conversations into each
    # training example instead of one conversation per row.
    # max_group_size = 16
    # num_repeats = 1
    # dataset = load_dataset("json", data_files=script_args.dataset_name)
    # rows = [dict(r) for r in dataset["train"]]
    # repeated_rows = []
    # for _ in range(num_repeats):
    #     random.shuffle(rows)
    #     repeated_rows.extend(rows)
    # grouped = []
    # while repeated_rows:
    #     row_group_size = random.randint(1, max_group_size)
    #     row_group = repeated_rows[:row_group_size]
    #     repeated_rows = repeated_rows[row_group_size:]
    #     msgs = []
    #     for row in row_group:
    #         msgs.extend(row["messages"])
    #     grouped.append({"messages": msgs})
    # dataset = Dataset.from_list(grouped).train_test_split(test_size=0.05)
    dataset = load_dataset("json", data_files=script_args.dataset_name)[
        "train"
    ].train_test_split(test_size=0.05)
    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()
    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)