# --- File 1 of 4: chain-of-thought dataset generation driver ---
import copy
import json
import os
import logging
import random
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass

import fire
from openai import OpenAI
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset
from tqdm import tqdm

from rank_cot_dataset import rank_question
BASE_URL = "http://127.0.0.1:8086/v1"
API_KEY = "asdf"
ELICIT_PROMPT = "\n\nNOTE: Before you respond, use...\n<thoughts>\n<thought>...</thought>\n<thought>...</thought>\n...\n</thoughts>\n to jot down various thoughts, facts, ideas, plans, algorithms, or anything you think will help yourself answer better."
ELICIT_LEVEL = 1  # 0, 1, or 2
OUTPUT_FILE = "outputs.jsonl"
BEST_OUTPUT_FILE = "best_outputs.jsonl"
TRAINING_FILE = "training_data.jsonl"
TRAINING_FILE0 = "training_data0.jsonl"
TOKENIZER_NAME = "meta-llama/Llama-3.2-3B-Instruct"
RETRY_ATTEMPTS = 3
NUM_WORKERS = 96
DEPTH = 6
DATASET_SEED = 58
ALPHANUMERIC = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
CAP_ALNUM_LIST = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
USE_TRAINING_DATA_AS_HISTORY = True
# RANK_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
# RANK_MODEL = "gemini-1.5-flash-8b"
RANK_API_KEY = "asdf"
RANK_BASE_URL = "http://127.0.0.1:8087/v1"
RANK_MODEL = "allenai/Llama-3.1-Tulu-3-8B"
STOP_WORDS = [
    "</thought>",
    "</thoughts>",
    "<thought",
    "<|im_end|>",
    "<|eot_id|>",
    "<|end_header_id|>",
    "<|start_header_id|>",
    "<|end",
    "<|start",
]
VERBOSE = True
# logging.basicConfig(level=logging.INFO)

@dataclass
class Message:
    content: str
    elicit_prompt: str = ELICIT_PROMPT

    @property
    def full_content(self) -> str:
        if ELICIT_LEVEL > 0:
            return f"{self.content}{self.elicit_prompt}"
        return self.content
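
# Illustrative behavior (a sketch, not executed): with ELICIT_LEVEL > 0,
# full_content appends the elicitation note to the raw prompt, e.g.
#     Message("What is 2+2?").full_content
#     -> "What is 2+2?\n\nNOTE: Before you respond, use...</thoughts>\n to jot down ..."
# With ELICIT_LEVEL == 0 it returns the prompt unchanged.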

def write_jsonl(data, filepath):
    """Write a single result to jsonl file and flush"""
    try:
        with open(filepath, "a") as f:
            json.dump(data, f)
            f.write("\n")
            f.flush()
        if VERBOSE and filepath == TRAINING_FILE:
            print(f"Successfully wrote training data to {filepath}")
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        logging.error(f"Error writing to {filepath}: {str(e)}")
        traceback.print_exc()

def load_progress():
    """Load previously processed prompts and results"""
    completed_prompts = set()
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            for line in f:
                try:
                    data = json.loads(line.strip())
                    completed_prompts.add(data["prompt"])
                except json.JSONDecodeError:
                    logging.warning("Skipping corrupted line in output file")
                    continue
    return completed_prompts

def safe_api_call(func, *args, max_retries=RETRY_ATTEMPTS, **kwargs):
    """Make API calls with retry logic"""
    for attempt in range(max_retries):
        try:
            return func(*args, **kwargs)
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception as e:
            logging.warning(f"API call failed (attempt {attempt + 1}): {str(e)}")
            if attempt == max_retries - 1:
                return None
            continue

def get_thought_response(
    client, client_model, prompt, thoughts, suffix, thought_chunk_size
):
    """Get thought response with error handling"""
    try:
        full_prompt = f"{prompt}<thoughts>\n"
        for thought in thoughts:
            full_prompt += f"<thought>{thought}</thought>\n"
        full_prompt += f"<thought>{suffix}"
        output = safe_api_call(
            client.completions.create,
            model=client_model.id,
            prompt=full_prompt,
            max_tokens=thought_chunk_size,
            temperature=0,
            stop=STOP_WORDS + ["\n"],
            frequency_penalty=1.05,
        )
        if output:
            return suffix + output.choices[0].text.strip()
        return None
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        logging.error(f"Error in get_thought_response: {str(e)}")
        return None
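
# Illustrative prompt assembly (a sketch): with thoughts=["Check the units"] and
# suffix="A", the completion prompt ends with
#     ...<thoughts>\n<thought>Check the units</thought>\n<thought>A
# so the model must continue a thought that begins with "A". Looping the suffix
# over ALPHANUMERIC (see process_single_message below) elicits up to 62 distinct
# thought candidates per depth step.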

def get_response(client, client_model, prompt, thoughts, max_tokens):
    """Get response with error handling"""
    try:
        full_prompt = f"{prompt}<thoughts>\n"
        for thought in thoughts:
            full_prompt += f"<thought>{thought}</thought>\n"
        full_prompt += "</thoughts>\n\n"
        output = safe_api_call(
            client.completions.create,
            model=client_model.id,
            prompt=full_prompt,
            max_tokens=max_tokens,
            temperature=0,
            stop=STOP_WORDS,
        )
        if output:
            return output.choices[0].text.strip()
        return None
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        logging.error(f"Error in get_response: {str(e)}")
        return None

def setup_clients():
    client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
    client_model = client.models.list().data[0]
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    return client, client_model, tokenizer

def prepare_dataset(sample_size, convo_count, random_seed):
    dataset = load_dataset("allenai/tulu-3-sft-mixture")["train"]
    history_convos = list(
        map(
            dict,
            dataset.shuffle(seed=random_seed).select(range(sample_size + convo_count)),
        )
    )
    return history_convos

def extract_prompt_messages(history_convos, convo_count):
    prompt_messages = []
    for _ in range(convo_count):
        msg = history_convos.pop(0)["messages"][0]
        prompt_messages.append(Message(content=msg["content"]))
    return prompt_messages

def prepare_messages(
    message, history_convos, tokenizer, max_tokens, output_tokens, reserved_tokens
):
    msgs = []
    msg = {"role": "user", "content": message.full_content}
    convs = []
    if ELICIT_LEVEL > 0:
        if ELICIT_LEVEL == 2:
            # Copy before shuffling: history_convos is shared across worker
            # threads, so shuffling it in place would race.
            convs = list(history_convos)
            random.shuffle(convs)
        elif ELICIT_LEVEL == 1:
            convs = random.sample(history_convos, 3)
        else:
            raise ValueError(f"Invalid ELICIT_LEVEL: {ELICIT_LEVEL}")
    else:
        convs = []
    for convo in convs:
        potential_prompt = tokenizer.apply_chat_template(
            copy.deepcopy(msgs) + convo["messages"] + [msg],
            tokenize=True,
            add_generation_prompt=True,
        )
        if len(potential_prompt) > max_tokens - output_tokens - reserved_tokens:
            break
        msgs.extend(convo["messages"])
    msgs.append(msg)
    return msgs
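
# Packing sketch: the loop above greedily prepends whole history conversations
# as few-shot context, stopping as soon as rendering msgs + the next
# conversation + the new user message would exceed
# max_tokens - output_tokens - reserved_tokens; the user message itself is
# always appended last.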

def process_single_message(args):
    (
        message,
        history_convos,
        client,
        client_model,
        tokenizer,
        max_tokens,
        depth,
        output_tokens,
        reserved_tokens,
        thought_chunk_size,
    ) = args
    # Initialized before the try block so the exception handler at the bottom
    # can reference them even if an early step (e.g. prepare_messages) fails.
    rank_client = OpenAI(api_key=RANK_API_KEY, base_url=RANK_BASE_URL)
    thought_stack = []
    answer_stack = []
    all_results = []
    best_results = []
    try:
        msgs = prepare_messages(
            message,
            history_convos,
            tokenizer,
            max_tokens,
            output_tokens,
            reserved_tokens,
        )
        base_prompt = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
        for _ith in range(depth):
            # if VERBOSE:
            #     thought_stack_fmt = " ".join(f"({t[:10]})" for t in thought_stack)
            # Develop thoughts
            new_thoughts = []
            for char in ALPHANUMERIC:
                new_thought = get_thought_response(
                    client,
                    client_model,
                    base_prompt,
                    thought_stack,
                    char,
                    thought_chunk_size,
                )
                if new_thought:
                    # if VERBOSE:
                    #     print(
                    #         f"THOUGHT: {{{len(thought_stack)}}} [Q:{json.dumps(message.content[-30:])}] {thought_stack_fmt} {new_thought}"
                    #     )
                    new_thoughts.append(new_thought)
            # Generate results for each thought
            results = []
            seen_mini_ids = set()
            for thought in new_thoughts:
                if len(thought) <= 2:
                    continue
                response = get_response(
                    client,
                    client_model,
                    base_prompt,
                    thought_stack + [thought],
                    output_tokens,
                )
                if not response:
                    continue
                # if VERBOSE:
                #     print(
                #         f"RESPONSE: {{{len(thought_stack)}}} [Q:{json.dumps(message.content[-30:])}] {thought_stack_fmt} ({thought[:10]}) {json.dumps(response)[-100:]}"
                #     )
                mini_id = ""
                while mini_id == "" or mini_id in seen_mini_ids:
                    mini_id = f"{random.choice(CAP_ALNUM_LIST)}{random.choice(CAP_ALNUM_LIST)}"
                seen_mini_ids.add(mini_id)
                results.append(
                    {
                        "mini_id": mini_id,
                        "prompt": message.content,
                        "base_prompt": base_prompt,
                        "thoughts": copy.deepcopy(thought_stack) + [thought],
                        "answer": response,
                    }
                )
            if not results:
                # No thought produced a usable answer at this depth; rank and
                # return what we have so far.
                ranked_best_results = rank_question(
                    client=rank_client, rows=best_results, model=RANK_MODEL
                )
                for rank, result in enumerate(ranked_best_results):
                    result["rank"] = rank
                return all_results, ranked_best_results
            # Rank the thoughts by their results
            ranked_results = rank_question(
                client=rank_client, rows=results, model=RANK_MODEL
            )
            if not ranked_results:
                logging.error(
                    f"Ranking unexpectedly returned nothing for {len(results)} results"
                )
                ranked_best_results = rank_question(
                    client=rank_client, rows=best_results, model=RANK_MODEL
                )
                for rank, result in enumerate(ranked_best_results):
                    result["rank"] = rank
                return all_results, ranked_best_results
            for rank, result in enumerate(ranked_results):
                result["rank"] = rank
            thought_stack.append(ranked_results[0]["thoughts"][-1])
            answer_stack.append(ranked_results[0]["answer"])
            all_results.extend(ranked_results)
            best_results.insert(0, ranked_results[0])
        ranked_best_results = rank_question(
            client=rank_client, rows=best_results, model=RANK_MODEL
        )
        for rank, result in enumerate(ranked_best_results):
            result["rank"] = rank
        return all_results, ranked_best_results
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        # logging.error(f"Error processing prompt {message.content}: {str(e)}")
        exc = "\n".join(traceback.format_exception(e))
        logging.error(f"Error processing prompt: {exc}")
        ranked_best_results = rank_question(
            client=rank_client, rows=best_results, model=RANK_MODEL
        )
        for rank, result in enumerate(ranked_best_results):
            result["rank"] = rank
        return all_results, ranked_best_results
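
# The loop above is a greedy depth-wise search (a beam of width 1): at each of
# `depth` steps it elicits up to 62 candidate thoughts (one per ALPHANUMERIC
# seed character), answers the question once per candidate, ranks the answers
# with rank_question, and pushes only the top-ranked thought onto thought_stack
# before descending to the next step.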

def process_messages_parallel(prompt_messages, history_convos, config):
    client, client_model, tokenizer = setup_clients()
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        args_list = [
            (
                msg,
                history_convos,
                client,
                client_model,
                tokenizer,
                config["max_tokens"],
                config["depth"],
                config["output_tokens"],
                config["reserved_tokens"],
                config["thought_chunk_size"],
            )
            for msg in prompt_messages
        ]
        futures = [executor.submit(process_single_message, args) for args in args_list]
        all_results = []
        for future in tqdm(
            as_completed(futures), total=len(futures), desc="Processing prompts"
        ):
            try:
                results, best_batch_results = future.result()
                if VERBOSE:
                    print(f"Got {len(best_batch_results)} best results from batch")
                all_results.extend(results)
                # Write results as they complete
                for result in results:
                    write_jsonl(result, OUTPUT_FILE)
                for result in best_batch_results:
                    write_jsonl(result, BEST_OUTPUT_FILE)
                if best_batch_results:
                    if VERBOSE:
                        print(
                            f"Processing {len(best_batch_results)} best results for training data"
                        )
                    for best in best_batch_results:
                        try:
                            thoughts_fmt = "<thoughts>\n"
                            for thought in best["thoughts"]:
                                thoughts_fmt += f"<thought>{thought}</thought>\n"
                            thoughts_fmt += f"</thoughts>\n\n{best['answer']}"
                            train_row = {
                                "messages": [
                                    {"role": "user", "content": best["prompt"]},
                                    {"role": "assistant", "content": thoughts_fmt},
                                ]
                            }
                            write_jsonl(train_row, TRAINING_FILE)
                        except (SystemExit, KeyboardInterrupt):
                            raise
                        except Exception as e:
                            logging.error(
                                f"Error processing best result for training: {str(e)}"
                            )
                            traceback.print_exc()
                        # Break on purpose: we only want the top-ranked row
                        break
            except (SystemExit, KeyboardInterrupt):
                raise
            except Exception as e:
                logging.error(f"Error processing future result: {str(e)}")
                traceback.print_exc()
    return all_results

def main(
    sample_size: int = 100,
    random_seed: int = DATASET_SEED,
    output_tokens: int = 4096,
    reserved_tokens: int = 1024,
    max_tokens: int = 65536,
    thought_chunk_size: int = 128,
):
    config = {
        "sample_size": sample_size,
        "random_seed": random_seed,
        "output_tokens": output_tokens,
        "reserved_tokens": reserved_tokens,
        "max_tokens": max_tokens,
        "depth": DEPTH,
        "thought_chunk_size": thought_chunk_size,
    }
    completed_prompts = load_progress()
    history_convos = prepare_dataset(sample_size, 1000, random_seed)
    prompt_messages = extract_prompt_messages(history_convos, 1000)
    # Filter out already completed prompts
    prompt_messages = [
        msg for msg in prompt_messages if msg.content not in completed_prompts
    ]
    if USE_TRAINING_DATA_AS_HISTORY and os.path.exists(TRAINING_FILE0):
        with open(TRAINING_FILE0, "r") as f:
            old_history_convos = history_convos
            history_convos = [json.loads(line.strip()) for line in f if line.strip()]
        if VERBOSE:
            print(
                f"Replacing historical convos ({len(old_history_convos)}) with generated ones with thoughts ({len(history_convos)})"
            )
    results = process_messages_parallel(prompt_messages, history_convos, config)
    return results

if __name__ == "__main__":
    fire.Fire(main)

# --- File 2 of 4: merge generated CoT rows into chat-format SFT training data ---
import json
import random

import fire

# from transformers import AutoTokenizer


def main(
    *filenames,
    out_filename: str = "merged_training_data.jsonl",
    expand_factor: int = 4,
    include_system: bool = True,
    include_system_pct: float = 0.4,
    system_thoughts_pct: float = 0.75,
    system_number_pct: float = 0.5,
    random_seed: int = 42,
):
    random.seed(random_seed)
    data = []
    seen_prompts = set()
    for filename in filenames:
        print(f"Opening {filename}")
        with open(filename) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                row = json.loads(line)
                if row["prompt"] in seen_prompts:
                    continue
                seen_prompts.add(row["prompt"])
                data.append(row)
print(f"Found {len(data)} records") | |
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct") | |
records_written = 0 | |
with open(out_filename, "w") as f: | |
for _ in range(expand_factor): | |
random.shuffle(data) | |
for row in data: | |
response = f"<thoughts>\n" | |
for thought in row["thoughts"]: | |
response += f"<thought>{thought}</thought>\n" | |
response += f"</thoughts>\n\n{row['answer']}" | |
messages = [] | |
if include_system and random.random() < include_system_pct: | |
system_message = "You are a helpful assistant." | |
if random.random() < system_thoughts_pct: | |
if random.random() < system_number_pct: | |
system_message += f" Think {len(row['thoughts'])} thought{'' if len(row['thoughts']) == 1 else 's'} before responding." | |
else: | |
system_message += " Think before responding." | |
messages += [{"role": "system", "content": system_message}] | |
messages += [ | |
{"role": "user", "content": row["prompt"]}, | |
{"role": "assistant", "content": response}, | |
] | |
# prompt = tokenizer.apply_chat_template(messages, tokenize=False) | |
f.write(json.dumps({"messages": messages}) + "\n") | |
records_written += 1 | |
print(f"Wrote {records_written} records to {out_filename}") | |
if __name__ == "__main__": | |
fire.Fire(main) |
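
# Example output row (illustrative; contents elided): with the defaults, roughly
# 40% of rows get a system message, 75% of those mention thinking, and half of
# those state an explicit thought count:
# {"messages": [
#   {"role": "system", "content": "You are a helpful assistant. Think 3 thoughts before responding."},
#   {"role": "user", "content": "..."},
#   {"role": "assistant", "content": "<thoughts>\n<thought>...</thought>\n...</thoughts>\n\n..."}]}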

# --- File 3 of 4: rank_cot_dataset.py (rank question-answer rows with an LLM judge) ---
import json
import random
import math
import time
from typing import Optional

from openai import OpenAI
from tqdm import tqdm

# Constants
SYNCHRONOUS = False
API_KEY = "asdf"
BASE_URL = "http://localhost:8085/v1"
MODEL = "meta-llama/Llama-3.2-3B-Instruct"

SYSTEM_MESSAGE = """
You are an expert evaluator of question-answer pairs, tasked with analyzing and ranking responses based on their quality. You will assess each (id, question, answer) triplet and output a JSON array of IDs ranked from highest to lowest quality.

Primary Evaluation Criteria (70% weight):

1. Question-Answer Alignment & Relevance (25%)
- Complete addressing of all explicit and implicit requirements
- Direct relevance to the query intent
- Appropriate scope (neither over-answering nor under-answering)
- Recognition of context and constraints
- Anticipation of related follow-up questions

2. Information Quality & Accuracy (25%)
- Factual correctness and precision
- Logical coherence and flow
- Evidence and citations when needed
- Absence of hallucinations or fabrications
- Appropriate confidence level
- Handling of edge cases and assumptions

3. Communication Effectiveness (20%)
- Clear, well-structured presentation
- Appropriate technical depth for audience
- Efficient use of language
- Professional and engaging tone
- Accessibility and readability
- Effective use of formatting/organization

Secondary Criteria (30% weight):

4. Cognitive Value (10%)
- Critical thinking demonstration
- Novel perspectives and insights
- Connection-making ability
- Problem-solving approach
- Educational value

5. Practical Value (10%)
- Real-world applicability
- Actionable advice or steps
- Implementation feasibility
- Time relevance
- Resource considerations

6. Ethical & Social Aspects (10%)
- Cultural sensitivity and awareness
- Bias detection and fairness
- Ethical implications
- Source attribution
- Impact consideration

Automatic Disqualifiers:
- Harmful or unethical content
- Deliberately misleading information
- Complete topic misalignment
- Severe factual errors
- Internal contradictions
- Nonsensical or incoherent content

Special Considerations:
- Question type and purpose
- Audience expertise level
- Time sensitivity
- Cultural context
- Required evidence/citations
- Expected format requirements

Ranking Process:
1. Evaluate each response against weighted criteria
2. Apply context-specific requirements
3. Check for disqualifying issues
4. Consider special factors
5. Sort responses from highest to lowest quality
6. Verify all IDs included exactly once

Output Format:
{"ranking": ["ID1", "ID2", "ID3", ...]}

Your output must be clean JSON only, with no additional text or formatting. If ranking cannot be determined, output {"ranking": []}.
""".strip()

class Player:
    def __init__(self, id_str: str):
        self.id = id_str
        self.mu = 25.0  # Initial mean
        self.sigma = 8.33  # Initial standard deviation
        self.num_comparisons = 0

    @property
    def rating(self):
        return self.mu - 3 * self.sigma  # Conservative rating estimate
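
# Worked number: with the initial values, rating = 25.0 - 3 * 8.33 = 0.01, so
# every player starts near zero and climbs only as wins accumulate and sigma
# shrinks. (The mu = 25, sigma = 25/3 defaults mirror TrueSkill's conventions.)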

def _qa(row: dict) -> str:
    return f"""
<QuestionAnswerPair>
<id>{row['mini_id']}</id>
<Question>{row['prompt']}</Question>
<Answer>{row['answer']}</Answer>
</QuestionAnswerPair>
""".strip()


def _validate_ranking(ranking: list, rows: list[dict]) -> Optional[str]:
    if not isinstance(ranking, list):
        return f"Invalid ranking (not a list): {ranking}"
    if not all(isinstance(item, str) for item in ranking):
        return f"Invalid ranking (not a list of strings): {ranking}"
    if len(ranking) != len(rows):
        return f"Invalid ranking (wrong length): {ranking}"
    if len(set(ranking)) != len(ranking):
        return f"Invalid ranking (duplicates): {ranking}"
    if set(ranking) != set(row["mini_id"] for row in rows):
        return f"Invalid ranking (wrong ids): {ranking}"
    return None

def get_ranking(
    client: OpenAI,
    rows: list[dict],
    model: str = MODEL,
    temperature: float = 0.0,
    num_attempts: int = 2,
) -> Optional[list[str]]:
    """Ask the judge model for a ranking, retrying with feedback on validation errors"""
    if not rows:
        return rows
    user_content = "\n".join(_qa(row=row) for row in rows)
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {
            "role": "user",
            "content": f"<QuestionAnswerPairs>\n{user_content}\n</QuestionAnswerPairs>",
        },
    ]
    schema = {
        "type": "object",
        "properties": {
            "ranking": {
                "type": "array",
                "items": {"type": "string"},
                "description": "The ranking of the QAPairs, from best to worst.",
            },
        },
        "required": ["ranking"],
        "additionalProperties": False,
    }
    for _ in range(num_attempts):
        content = None
        try:
            response = client.chat.completions.create(
                model=model,
                temperature=temperature,
                messages=messages,
                response_format={
                    "type": "json_schema",
                    "json_schema": {
                        "name": "ranking",
                        "strict": True,
                        "schema": schema,
                    },
                },
            )
            content = response.choices[0].message.content.strip()
            ranking = json.loads(content).get("ranking", []) or []
            validation_error = _validate_ranking(ranking, rows)
            if not validation_error:
                return ranking
            messages.append({"role": "assistant", "content": content})
            messages.append(
                {
                    "role": "user",
                    "content": f"<ValidationError>\n{validation_error}\n</ValidationError>\n<MessageToRanker>\nPlease try again (to rank the given QA pairs based on the provided criteria) but correct the validation error\n</MessageToRanker>\n<RepeatedQuestion>\n<QuestionAnswerPairs>\n{user_content}\n</QuestionAnswerPairs>\n</RepeatedQuestion>",
                }
            )
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception as e:
            print(f"WARNING: {e}")
            if content:
                messages.append({"role": "assistant", "content": content})
                messages.append(
                    {
                        "role": "user",
                        "content": f"ERROR: {str(e)}",
                    }
                )
            else:
                raise ValueError(
                    f"Failed to get valid ranking ({num_attempts}) - INFERENCE_ERROR"
                )
    raise ValueError(f"Failed to get valid ranking after {num_attempts} attempts")

def update_ratings(winner: Player, loser: Player, k: float = 0.5):
    """Update ratings using a Bayesian update with a dynamic K-factor"""
    # Shrink the K-factor as players accumulate comparisons
    k_factor = k / math.sqrt(max(1, min(winner.num_comparisons, loser.num_comparisons)))
    expected_win = 1.0 / (
        1.0
        + math.exp(
            (loser.mu - winner.mu) / math.sqrt(2.0 * (winner.sigma**2 + loser.sigma**2))
        )
    )
    # Update step sizes based on uncertainty
    winner_step = k_factor * winner.sigma
    loser_step = k_factor * loser.sigma
    # Update means
    winner.mu += winner_step * (1 - expected_win)
    loser.mu -= loser_step * (1 - expected_win)
    # Update standard deviations (the 0.95 floor caps shrinkage at ~2.5% per update)
    winner.sigma *= math.sqrt(
        max(0.95, 1.0 - k_factor * expected_win * (1 - expected_win))
    )
    loser.sigma *= math.sqrt(
        max(0.95, 1.0 - k_factor * expected_win * (1 - expected_win))
    )
    # Increment comparison counts
    winner.num_comparisons += 1
    loser.num_comparisons += 1
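
# Worked example (a sketch, not part of the pipeline): two fresh players both
# start at mu=25.0 and sigma=8.33, so expected_win = 0.5 and the winner's mu
# rises by k_factor * sigma * 0.5 = 0.5 * 8.33 * 0.5 ≈ 2.08, while both sigmas
# shrink by the floor factor sqrt(0.95) ≈ 0.975:
#
#     a, b = Player("AA"), Player("BB")
#     update_ratings(a, b)
#     # a.mu ≈ 27.08, b.mu ≈ 22.92, a.sigma == b.sigma ≈ 8.12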

def rank_question(
    client: OpenAI,
    rows: list[dict],
    model: str = MODEL,
    group_size: int = 3,
    num_groups: Optional[int] = None,
) -> Optional[list[dict]]:
    """Main ranking function using Bayesian skill ratings with comparison tracking"""
    if not rows:
        return rows
    num_items = len(rows)
    # Calculate optimal number of groups
    if num_groups is None:
        try:
            # We need roughly n*log(n) comparisons total for a good ranking
            target_comparisons = int(num_items * math.log2(num_items))
            # Each group of size k generates k(k-1)/2 comparisons
            comparisons_per_group = (group_size * (group_size - 1)) / 2
            # Calculate groups needed to get close to target comparisons
            num_groups = max(
                2,  # Minimum of 2 groups
                min(
                    int(target_comparisons / comparisons_per_group),
                    2 * math.ceil(num_items / group_size),  # Cap at 2x coverage
                ),
            )
        except (ValueError, ZeroDivisionError):
            num_groups = max(2, math.ceil(num_items / max(1, group_size)))
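    # Worked numbers (illustrative): for num_items = 62 (one candidate per
    # ALPHANUMERIC seed) and group_size = 3, target_comparisons =
    # int(62 * log2(62)) = 369 and comparisons_per_group = 3, so 369 / 3 = 123
    # groups are wanted, but the 2x-coverage cap 2 * ceil(62 / 3) = 42 wins,
    # giving num_groups = 42.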
    # Initialize players
    players = {row["mini_id"]: Player(row["mini_id"]) for row in rows}
    compared_pairs = set()  # Track all compared pairs (as sorted id tuples)
    # Generate and process groups
    for _ in tqdm(range(num_groups), desc="Ranking groups"):
        # Select items without replacement, preferring those with fewer
        # comparisons. (random.choices samples WITH replacement, which could
        # put the same row in a group twice and make a valid ranking
        # impossible, so we draw one candidate at a time.)
        candidates = list(rows)
        group = []
        while candidates and len(group) < group_size:
            weights = [
                1.0 / (players[r["mini_id"]].num_comparisons + 1) for r in candidates
            ]
            pick = random.choices(candidates, weights=weights, k=1)[0]
            group.append(pick)
            candidates.remove(pick)
        # Skip if all pairs in this group have been compared
        group_ids = tuple(sorted(row["mini_id"] for row in group))
        if all(
            (a, b) in compared_pairs
            for i, a in enumerate(group_ids)
            for b in group_ids[i + 1 :]
        ):
            continue
        try:
            ranking = get_ranking(client=client, rows=group, model=model)
        except ValueError as e:
            print(f"WARN: {e}")
            # If ranking fails, skip this group
            continue
        if ranking is None:
            continue
        # Update ratings based on ranking
        for i in range(len(ranking)):
            for j in range(i + 1, len(ranking)):
                winner_id, loser_id = ranking[i], ranking[j]
                # Normalize the pair so (A, B) and (B, A) dedupe together;
                # storing them in outcome order would defeat the
                # "already compared" checks above.
                pair = tuple(sorted((winner_id, loser_id)))
                if pair in compared_pairs:
                    continue
                compared_pairs.add(pair)
                try:
                    update_ratings(players[winner_id], players[loser_id])
                except (KeyboardInterrupt, SystemExit):
                    raise
                except Exception as e:
                    print(f"Error updating ratings: {e}")
                    continue
    # Sort rows based on final ratings
    sorted_ids = sorted(players.keys(), key=lambda x: players[x].rating, reverse=True)
    id_to_row = {row["mini_id"]: row for row in rows}
    ranked_rows = [id_to_row[id_str] for id_str in sorted_ids]
    return ranked_rows

# --- File 4 of 4: train_multi.py (SFT training with TRL) ---
""" | |
python src/minicot/train.py \ | |
python train_multi.py \ | |
accelerate launch --num_processes 7 train_multi.py \ | |
--model_name_or_path meta-llama/Llama-3.2-3B-Instruct \ | |
--dataset_name merged_training_data.jsonl \ | |
--learning_rate 1e-4 \ | |
--num_train_epochs 1 \ | |
--per_device_train_batch_size 1 \ | |
--per_device_eval_batch_size 1 \ | |
--gradient_accumulation_steps 8 \ | |
--gradient_checkpointing \ | |
--logging_steps 1 \ | |
--warmup_ratio 0.1 \ | |
--eval_strategy steps \ | |
--eval_steps 20 \ | |
--torch_dtype bfloat16 \ | |
--max_seq_length 16384 \ | |
--use_liger \ | |
--attn_implementation flash_attention_2 \ | |
--optim adamw_bnb_8bit \ | |
--output_dir Llama-3.2-3B-COTv2.3 | |
""" | |
import random

from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    setup_chat_format,
)

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        use_fast=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path,
            **model_kwargs,
        )
        model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
    else:
        # Pass the name through; SFTTrainer loads the model itself in this case
        model = model_config.model_name_or_path

    ################
    # Dataset
    ################
    # Earlier experiment (kept for reference): pack several conversations
    # into each training example.
    # max_group_size = 16
    # num_repeats = 1
    # dataset = load_dataset("json", data_files=script_args.dataset_name)
    # rows = [dict(r) for r in dataset["train"]]
    # repeated_rows = []
    # for _ in range(num_repeats):
    #     random.shuffle(rows)
    #     repeated_rows.extend(rows)
    # grouped = []
    # while repeated_rows:
    #     row_group_size = random.randint(1, max_group_size)
    #     row_group = repeated_rows[:row_group_size]
    #     repeated_rows = repeated_rows[row_group_size:]
    #     msgs = []
    #     for row in row_group:
    #         msgs.extend(row["messages"])
    #     grouped.append({"messages": msgs})
    # dataset = Dataset.from_list(grouped).train_test_split(test_size=0.05)
    dataset = load_dataset("json", data_files=script_args.dataset_name)[
        "train"
    ].train_test_split(test_size=0.05)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)