import math
import os
import random
import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor

import openai
import orjson
import polars as pl

from config.embed import (
    api_keys, key_copies, num_samples, batch_size, val_percent,
    retries, min_retry_sleep_duration, retry_jitter_range, input_file, out_dir,
    input_file_csv_column, examples_per_chunk, exclude_incomplete_chunks, batches_for_eta_calc,
)

assert examples_per_chunk > batch_size, "examples_per_chunk must be greater than batch_size."
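# The config module above is not included in this gist. A minimal sketch of what
# config/embed.py might contain, using only the names imported above; every value
# below is an illustrative placeholder, not taken from the original:
#
#     api_keys = ["sk-..."]              # one entry per OpenAI API key
#     key_copies = 2                     # concurrent workers per key
#     num_samples = 100_000              # rows to sample from the input file
#     batch_size = 128                   # inputs per embedding request
#     val_percent = 0.05                 # fraction held out for validation
#     retries = 5                        # attempts per batch before giving up
#     min_retry_sleep_duration = 2       # base backoff, in seconds
#     retry_jitter_range = (0.5, 1.5)    # multiplier range applied to the backoff
#     input_file = "data/inputs.csv"     # .txt (one input per line) or .csv
#     input_file_csv_column = "text"     # column to read when input_file is a .csv
#     out_dir = "checkpoints"            # where train/ and val/ chunks are written
#     examples_per_chunk = 4096          # embeddings per saved chunk file
#     exclude_incomplete_chunks = False  # skip trailing partial batches if True
#     batches_for_eta_calc = 10          # rolling window used for the ETA estimate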
# Repeat each key key_copies times, keeping each key's copies adjacent so the
# key/copy arithmetic in log() below lines up with the worker indices.
api_keys = [key for key in api_keys for _ in range(key_copies)]
# Create the checkpoints folders if they don't exist.
os.makedirs(out_dir, exist_ok=True)
os.makedirs(f"{out_dir}/train", exist_ok=True)
os.makedirs(f"{out_dir}/val", exist_ok=True)
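# Each worker writes completed chunks as {out_dir}/{split}/chunk{N}.json, where the file
# body is orjson-encoded as {"chunk": N, "data": [{"text": ..., "embedding": [...]}, ...]}.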
print("Loading inputs...") | |
if input_file.endswith(".txt"): | |
with open(input_file, "r") as f: | |
raw_lines = [line.strip() for line in f] | |
elif input_file.endswith(".csv"): | |
df = pl.read_csv(input_file) | |
raw_lines = df[input_file_csv_column].to_list() | |
print("Sampling randomly...") | |
inputs = random.sample(raw_lines, num_samples) | |
print("Inputs loaded & sampled.") | |
# Get the number of samples for training and validation.
val_samples = int(val_percent * num_samples)
train_samples = num_samples - val_samples
# Split the inputs into training and validation sets.
train_inputs = inputs[:train_samples]
val_inputs = inputs[train_samples:]
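# For example (illustrative numbers, not from the config): with num_samples = 100_000 and
# val_percent = 0.05, val_samples = 5_000 and train_samples = 95_000, so the first 95_000
# sampled inputs go to train and the remaining 5_000 to val.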
# For tracking progress while generating chunks.
start_time = time.time()
chunks_saved = 0
batches_processed = 0
last_batch_times = deque(maxlen=batches_for_eta_calc)
# Guards the shared progress counters, which are updated from multiple worker threads.
counter_lock = threading.Lock()
def log(raw_api_key_index, message):
    print(f"[Key #{raw_api_key_index // key_copies + 1} — Copy #{raw_api_key_index % key_copies + 1}] {message}")
def embed_batch(raw_api_key_index, batch_input):
    api_key = api_keys[raw_api_key_index]
    for attempt in range(retries):
        try:
            data = openai.Embedding.create(
                api_key=api_key,
                model="text-embedding-ada-002",
                input=batch_input,
            )["data"]
            return [entry["embedding"] for entry in data]
        except Exception as e:
            if attempt == retries - 1:
                log(raw_api_key_index,
                    f"Batch catastrophically failed after {attempt + 1} attempts. "
                    f"Not retrying again. Final error was: {str(e)}")
            else:
                log(raw_api_key_index, f"Batch failed with error: {str(e)}, retrying (attempt {attempt + 1})...")
                jitter = random.uniform(retry_jitter_range[0], retry_jitter_range[1])
                time.sleep(min_retry_sleep_duration * jitter)  # Sleep for a bit before retrying.
    # Every attempt failed: return no embeddings so the caller simply skips this batch
    # instead of crashing on a None return value.
    return []
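# Backoff illustration (hypothetical values, not from the config): with
# min_retry_sleep_duration = 2 and retry_jitter_range = (0.5, 1.5), each failed
# attempt sleeps for a uniformly random 1 to 3 seconds before the next try.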
def main():
    global chunks_saved, batches_processed
    for dataset, split in (
        (train_inputs, "train"),
        (val_inputs, "val"),
    ):
        # Reset progress tracking variables.
        chunks_saved = 0
        batches_processed = 0
        last_batch_times.clear()
        # Calculate the number of sequences each API key will be assigned
        # to process. We round up to ensure we process the whole dataset.
        block_size = math.ceil(len(dataset) / len(api_keys))
        # Calculate the total number of chunks / batches we'll save in total. We round
        # up to ensure we account for the last chunk, which may be smaller than the rest.
        total_chunks = math.ceil(block_size / examples_per_chunk) * len(api_keys)
        total_batches = math.ceil(block_size / batch_size) * len(api_keys)
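        # Worked example (illustrative numbers only): 8 worker slots and a 100_000-row dataset
        # give block_size = ceil(100_000 / 8) = 12_500; with examples_per_chunk = 4_096 that is
        # ceil(12_500 / 4_096) = 4 chunks per worker, i.e. total_chunks = 32, and with
        # batch_size = 128 it is ceil(12_500 / 128) = 98 batches per worker, i.e. total_batches = 784.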
print(f"Processing {split} dataset...") | |
def run_api_key_worker(raw_api_key_index, split_name, worker_inputs): | |
global chunks_saved, batches_processed, start_time, batches_for_eta_calc, last_batch_times | |
log(raw_api_key_index, f"Starting worker...") | |
chunk = [] # Will be extended with tuples of (input, embedding) pairs as we go. | |
for i in range(0, len(worker_inputs), batch_size): | |
# Get the next batch of inputs. | |
batch_input = worker_inputs[i:i + batch_size] | |
# Skip incomplete batches if requested via config option. | |
if exclude_incomplete_chunks and len(batch_input) < batch_size: | |
log(raw_api_key_index, f"Skipping batch because it's incomplete. Length: {len(batch_input)}") | |
continue | |
# Embed the inputs in the batch. | |
batch_embedding = embed_batch(raw_api_key_index, batch_input) | |
# Extend the chunk with the batch. | |
chunk.extend({"text": t, "embedding": e} for t, e in zip(batch_input, batch_embedding)) | |
# Cache the current batch number, so we can calculate ETA. | |
batch_processed = batches_processed | |
batches_processed += 1 # Update the batches processed counter. | |
# Update the batch times deque (for ETA calculations) | |
last_batch_times.append(time.time()) | |
                # If we've processed enough batches to calculate an ETA, do so.
                if len(last_batch_times) >= batches_for_eta_calc:
                    avg_time_per_batch = (last_batch_times[-1] - last_batch_times[0]) / (len(last_batch_times) - 1)
                    remaining_batches = total_batches - (batch_processed + 1)
                    eta_seconds = avg_time_per_batch * remaining_batches
                    eta_time = "{:02d}:{:02d}:{:02d}".format(
                        int(eta_seconds // 3600),
                        int((eta_seconds % 3600) // 60),
                        int(eta_seconds % 60),
                    )
                else:
                    eta_time = "N/A"  # Otherwise, just say the ETA is N/A.
                log(raw_api_key_index,
                    f"Processed batch {batch_processed + 1}/{total_batches} with {len(batch_input)} entries. "
                    f"ETA: {eta_time}")
                # If we've reached the end of the chunk, or the last batch, save the chunk.
                if len(chunk) >= examples_per_chunk or i + batch_size >= len(worker_inputs):
                    with counter_lock:
                        chunk_to_write = chunks_saved  # Cache the current chunk number.
                        # Update now so another thread doesn't try to save the same chunk.
                        chunks_saved += 1
                    log(raw_api_key_index, f"Saving chunk {chunk_to_write}...")
                    with open(f"{out_dir}/{split_name}/chunk{chunk_to_write}.json", "wb") as f:
                        f.write(orjson.dumps({"chunk": chunk_to_write, "data": chunk}))
                    log(raw_api_key_index,
                        f"Completed chunk {chunk_to_write + 1}/{total_chunks} with {len(chunk)} entries.")
                    # Reset the chunk.
                    chunk = []
        with ThreadPoolExecutor(max_workers=len(api_keys)) as executor:
            futures = []
            for api_key_i in range(len(api_keys)):
                futures.append(executor.submit(
                    run_api_key_worker,
                    api_key_i,
                    split,
                    # Slice the dataset into blocks of size block_size.
                    # It's okay that we round up block_size here, because if we
                    # overshoot the end of the dataset, the slice just stops at the end.
                    dataset[block_size * api_key_i:block_size * (api_key_i + 1)],
                ))
            # Wait for the current dataset to finish before processing the next one,
            # and surface any exceptions raised inside the worker threads.
            for future in futures:
                future.result()
        time.sleep(10)  # Sleep for good measure.
if __name__ == "__main__":
    main()
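# Note: openai.Embedding.create is the pre-1.0 openai-python interface, so this script assumes
# openai<1.0 alongside polars and orjson. Run it from the project root (so that config.embed
# resolves), e.g. `python embed_dataset.py` (the filename is illustrative).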