Teach a model to decompose word-level tokens into their constituent characters
import json
import random
import string
from typing import List, Dict, Tuple, Callable, Any
from dataclasses import dataclass, field
import logging
from collections import defaultdict, Counter

from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@dataclass
class Language:
    name: str
    dictionary: List[str] = field(default_factory=list)
    weight: float = 1.0

@dataclass
class SeparatorConfig:
    default: str = '→'
    options: List[str] = field(default_factory=lambda: ['→', '\n', '·', '►'])
    explicit_ratio: float = 0.7
    word_separators: List[str] = field(default_factory=lambda: [' ', '-', '_', ''])
    word_separator_ratio: float = 0.3  # Probability of using a non-space word separator
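
# Illustrative behavior of the settings above (values assumed, not prescriptive):
# with letter separator '→' and word separator ' ', the words ["cat", "dog"]
# are spelled out as "c→a→t→ →d→o→g"; with the empty word separator '' the
# spelled words are joined directly. explicit_ratio is the probability that the
# user instruction names the separator explicitly.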
@dataclass
class Config:
    total_samples: int = 100000
    train_test_ratio: float = 0.9
    # Annotated as dataclass fields (with default_factory for mutable values)
    # so each Config instance gets its own copies instead of sharing
    # class-level attributes.
    languages: List[Tuple[str, float]] = field(default_factory=lambda: [
        ("english", 4.0),
        ("german", 1.0),
        ("french", 1.0),
        ("italian", 1.0),
        ("portuguese", 1.0),
        ("spanish", 1.0)
    ])
    task_weights: Dict[str, float] = field(default_factory=lambda: {
        "spelling": 0.2, "char_count": 0.35, "char_substitution": 0.45
    })
    separator: SeparatorConfig = field(default_factory=SeparatorConfig)
    noise_probabilities: Dict[str, float] = field(default_factory=lambda: {
        'capitalize_first': 0.15,
        'capitalize_random': 0.10,
        'capitalize_all': 0.10,
        'lowercase': 0.65
    })
    system_prompts: Dict[str, List[str]] = field(default_factory=lambda: {
        "english": [
            "You are a helpful assistant that performs language-related tasks.",
            "You're an AI trained to analyze and process text in various languages.",
            "As a text processor, you can handle various language tasks.",
            "Your role is to assist with linguistic analysis and text manipulation.",
            ""
        ],
        "german": [
            "Sie sind ein hilfreicher Assistent, der sprachbezogene Aufgaben ausführt.",
            "Sie sind eine KI, die darauf trainiert wurde, Texte in verschiedenen Sprachen zu analysieren und zu verarbeiten.",
            "Als Textprozessor können Sie verschiedene Sprachaufgaben bewältigen.",
            "Ihre Aufgabe ist es, bei der linguistischen Analyse und Textmanipulation zu unterstützen.",
            ""
        ],
        "french": [
            "Vous êtes un assistant utile qui effectue des tâches liées au langage.",
            "Vous êtes une IA formée pour analyser et traiter du texte dans diverses langues.",
            "En tant que processeur de texte, vous pouvez gérer diverses tâches linguistiques.",
            "Votre rôle est d'aider à l'analyse linguistique et à la manipulation de texte.",
            ""
        ],
        "italian": [
            "Sei un assistente utile che esegue compiti legati al linguaggio.",
            "Sei un'IA addestrata per analizzare ed elaborare testi in varie lingue.",
            "Come elaboratore di testi, puoi gestire vari compiti linguistici.",
            "Il tuo ruolo è assistere nell'analisi linguistica e nella manipolazione del testo.",
            ""
        ],
        "portuguese": [
            "Você é um assistente útil que realiza tarefas relacionadas à linguagem.",
            "Você é uma IA treinada para analisar e processar texto em várias línguas.",
            "Como processador de texto, você pode lidar com várias tarefas linguísticas.",
            "Seu papel é auxiliar na análise linguística e manipulação de texto.",
            ""
        ],
        "spanish": [
            "Eres un asistente útil que realiza tareas relacionadas con el lenguaje.",
            "Eres una IA entrenada para analizar y procesar texto en varios idiomas.",
            "Como procesador de texto, puedes manejar varias tareas lingüísticas.",
            "Tu función es ayudar con el análisis lingüístico y la manipulación de texto.",
            ""
        ]
    })
@dataclass
class Task:
    name: str
    input_func: Callable[[List[str], str], Any]
    output_func: Callable[[Any, str, str, Config], str]
    instructions: Dict[str, List[str]]
    separator_instructions: Dict[str, str] = field(default_factory=dict)

class DiverseDatasetGenerator:
    def __init__(self, base_generator, model_name='all-MiniLM-L6-v2', similarity_threshold=0.95, batch_size=1000):
        self.base_generator = base_generator
        self.similarity_threshold = similarity_threshold
        self.batch_size = batch_size
        if self.similarity_threshold < 1.0:
            self.model = SentenceTransformer(model_name)
            self.index = None
            self.vectors = []
        else:
            self.model = None
            self.index = None
            self.vectors = None

    def generate_diverse_dataset(self):
        train_set, test_set = [], []
        batch = []
        total_samples = self.base_generator.config.total_samples
        pbar = tqdm(total=total_samples, desc="Generating samples")
        while len(train_set) + len(test_set) < total_samples:
            sample = self.base_generator.generate_single_sample()
            batch.append(sample)
            if len(batch) >= self.batch_size or len(train_set) + len(test_set) + len(batch) >= total_samples:
                diverse_batch = self.process_batch(batch)
                for sample in diverse_batch:
                    if len(train_set) < int(total_samples * self.base_generator.config.train_test_ratio):
                        train_set.append(sample)
                    else:
                        test_set.append(sample)
                pbar.update(len(diverse_batch))
                logging.info(f"Added {len(diverse_batch)} samples. Total: {len(train_set) + len(test_set)}")
                batch = []
        pbar.close()
        logging.info(f"Final dataset size: Train {len(train_set)}, Test {len(test_set)}")
        return train_set, test_set

    def process_batch(self, batch):
        if self.similarity_threshold == 1.0:
            return batch  # All samples are considered diverse when threshold is 1.0
        texts = [' '.join([msg['content'] for msg in sample['messages']]) for sample in batch]
        # Normalize embeddings so that the inner product computed by
        # IndexFlatIP below equals cosine similarity.
        vectors = self.model.encode(texts, normalize_embeddings=True)
        if self.index is None:
            self.index = faiss.IndexFlatIP(vectors.shape[1])
        diverse_batch = []
        for i, vector in enumerate(vectors):
            if self.is_diverse(vector):
                self.add_to_index(vector)
                diverse_batch.append(batch[i])
        logging.info(f"Processed batch: {len(batch)} samples, {len(diverse_batch)} diverse samples")
        return diverse_batch

    def is_diverse(self, vector):
        if self.similarity_threshold == 1.0:
            return True  # All samples are considered diverse when threshold is 1.0
        if self.index.ntotal == 0:
            return True
        D, _ = self.index.search(np.array([vector]).astype('float32'), 1)
        similarity = D[0][0]  # Cosine similarity
        return similarity < self.similarity_threshold

    def add_to_index(self, vector):
        if self.similarity_threshold < 1.0:
            self.index.add(np.array([vector]).astype('float32'))
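
# Diversity filtering in a nutshell: each accepted sample's embedding is added
# to a flat inner-product index, and a new sample is kept only if its nearest
# stored neighbour has cosine similarity below the threshold. A threshold of
# 1.0 turns the filter into a pass-through.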
class DatasetGenerator:
    def __init__(self, languages: List[Language], tasks: List[Task], config: Config):
        self.languages = [lang for lang in languages if lang.dictionary]
        self.tasks = tasks
        self.config = config
        self.normalize_language_weights()
        self.system_prompts: Dict[str, List[str]] = self.config.system_prompts

    def normalize_language_weights(self) -> None:
        """Normalize language weights to ensure they sum to 1."""
        total_weight = sum(lang.weight for lang in self.languages)
        for lang in self.languages:
            lang.weight /= total_weight

    def apply_noise(self, word: str) -> str:
        """Apply random noise to the input word based on configured probabilities."""
        noise_type = random.choices(
            list(self.config.noise_probabilities.keys()),
            weights=list(self.config.noise_probabilities.values())
        )[0]
        if noise_type == 'capitalize_first':
            return word.capitalize()
        elif noise_type == 'capitalize_random':
            index = random.randint(0, len(word) - 1)
            return word[:index] + word[index].upper() + word[index+1:]
        elif noise_type == 'capitalize_all':
            return word.upper()
        else:
            return word.lower()
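
    # Illustrative apply_noise outcomes for "cat", with probabilities taken
    # from Config.noise_probabilities: "Cat" (capitalize_first), "cAt"
    # (capitalize_random), "CAT" (capitalize_all), "cat" (lowercase).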
    def apply_word_separator(self, words: List[str]) -> str:
        if random.random() < self.config.separator.word_separator_ratio:
            separator = random.choice(self.config.separator.word_separators)
        else:
            separator = ' '
        return separator.join(words)

    def generate_single_sample(self) -> Dict:
        language = random.choices(self.languages, weights=[lang.weight for lang in self.languages])[0]
        task = random.choices(self.tasks, weights=[self.config.task_weights[task.name] for task in self.tasks])[0]
        # Randomly choose to operate on 1, 2, or 3 words
        num_words = random.choices([1, 2, 3], weights=[0.6, 0.3, 0.1])[0]
        words = [random.choice(language.dictionary) for _ in range(num_words)]
        noisy_words = [self.apply_noise(word) for word in words]
        combined_words = self.apply_word_separator(noisy_words)
        use_explicit_separator = random.random() < self.config.separator.explicit_ratio
        separator = random.choice(self.config.separator.options) if use_explicit_separator else self.config.separator.default
        input_data = task.input_func(noisy_words, combined_words)
        output_data = task.output_func(input_data, language.name, separator, self.config)
        if task.name == "spelling":
            formatted_instruction = random.choice(task.instructions[language.name]).format(words=combined_words)
        elif task.name == "char_count":
            char = input_data[2]
            formatted_instruction = random.choice(task.instructions[language.name]).format(words=combined_words, char=char)
        elif task.name == "char_substitution":
            old_char, new_char = input_data[1], input_data[2]
            formatted_instruction = random.choice(task.instructions[language.name]).format(words=combined_words, old_char=old_char, new_char=new_char)
        else:
            raise ValueError(f"Unknown task: {task.name}")
        if use_explicit_separator:
            separator_instruction = task.separator_instructions.get(language.name, "").format(separator=separator)
            formatted_instruction = f"{formatted_instruction} {separator_instruction}"
        system_message = random.choice(self.system_prompts[language.name])
        sample = {
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": formatted_instruction},
                {"role": "assistant", "content": output_data}
            ],
            "language": language.name,
            "task": task.name
        }
        return sample
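
    # Shape of an emitted sample (values illustrative):
    # {"messages": [
    #     {"role": "system", "content": "You are a helpful assistant ..."},
    #     {"role": "user", "content": "How do you spell 'cat'? Use '→' as a separator between letters."},
    #     {"role": "assistant", "content": "Original: cat\nSpelled: c→a→t"}],
    #  "language": "english", "task": "spelling"}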
    def generate_dataset(self) -> Tuple[List[Dict], List[Dict]]:
        dataset = []
        for _ in range(self.config.total_samples):
            sample = self.generate_single_sample()
            dataset.append(sample)
        random.shuffle(dataset)
        split_index = int(len(dataset) * self.config.train_test_ratio)
        train_set, test_set = dataset[:split_index], dataset[split_index:]
        return train_set, test_set

# End of classes
def load_dictionary(lang: str) -> List[str]:
    """
    Load a dictionary file for the given language.
    Supports both text files (one word per line) and JSON files (array of words).
    """
    file_name = f"{lang}_dictionary"
    # Try loading JSON file first
    try:
        with open(f"{file_name}.json", "r", encoding="utf-8") as f:
            data = json.load(f)
            if isinstance(data, list):
                return [word.strip() for word in data if word.strip()]
    except FileNotFoundError:
        pass
    except json.JSONDecodeError:
        logging.warning(f"Error decoding JSON file for {lang}.")
    # If JSON file is not found or invalid, try loading text file
    try:
        with open(f"{file_name}.txt", "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        pass
    logging.warning(f"Dictionary file for {lang} not found in either JSON or text format. Skipping this language.")
    return []
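
# Expected dictionary layouts (the file names are what load_dictionary looks
# for; the word lists themselves are up to you):
#   english_dictionary.txt   -> one word per line
#   english_dictionary.json  -> ["cat", "dog", "house", ...]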
# Task definitions
def spelling_input(words: List[str], combined_words: str) -> Tuple[List[str], str]:
    return words, combined_words

def spelling_output(input_data: Tuple[List[str], str], language: str, separator: str, config: Config) -> str:
    words, _ = input_data  # We'll ignore the original combined_words and create our own
    if random.random() < config.separator.word_separator_ratio:
        word_separator = random.choice(config.separator.word_separators)
    else:
        word_separator = ' '
    # Create combined_words with the chosen word_separator
    combined_words = word_separator.join(words)
    # Spell out each word, including the word separator
    spelled_words = []
    for i, word in enumerate(words):
        spelled_word = separator.join(word)
        spelled_words.append(spelled_word)
        # Add the spelled word separator between words; skip it when the word
        # separator is the empty string, which would otherwise inject an empty
        # element and double the letter separator in the joined output.
        if i < len(words) - 1 and word_separator:
            spelled_words.append(separator.join(word_separator))
    # Join all spelled components
    spelled_output = separator.join(spelled_words)
    responses = {
        "english": lambda o, s: f"Original: {o}\nSpelled: {s}",
        "german": lambda o, s: f"Original: {o}\nBuchstabiert: {s}",
        "french": lambda o, s: f"Original : {o}\nÉpelé : {s}",
        "italian": lambda o, s: f"Originale: {o}\nCompitato: {s}",
        "portuguese": lambda o, s: f"Original: {o}\nSoletrado: {s}",
        "spanish": lambda o, s: f"Original: {o}\nDeletreado: {s}"
    }
    return responses.get(language, responses["english"])(combined_words, spelled_output)
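
# Illustrative spelling_output result for ["cat", "dog"] with letter separator
# '→' and word separator ' ':
#   Original: cat dog
#   Spelled: c→a→t→ →d→o→g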
def char_count_input(words: List[str], combined_words: str) -> Tuple[List[str], str, str]:
    combined_words_lower = combined_words.lower()
    word_char_counts = Counter(combined_words_lower)
    all_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.punctuation + string.digits)
    non_word_chars = all_chars - set(combined_words_lower)
    if random.random() < 0.75 and word_char_counts:  # 75% chance to choose a character from the word
        # Group characters by their count
        count_groups = {}
        for char, count in word_char_counts.items():
            count_groups.setdefault(count, []).append(char)
        # Choose a count group, weighted towards less frequent counts
        weights = [1/count for count in count_groups.keys()]
        chosen_count = random.choices(list(count_groups.keys()), weights=weights)[0]
        # Choose a random character from the selected count group
        char = random.choice(count_groups[chosen_count])
    else:  # 25% chance to choose a character not in the word
        char = random.choice(list(non_word_chars)) if non_word_chars else random.choice(list(all_chars))
    return words, combined_words, char
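
# Example of the inverse-count weighting above: for "cat cat" the counts are
# {'c': 2, 'a': 2, 't': 2, ' ': 1}, so count group 2 gets weight 0.5 and group
# 1 gets weight 1.0, meaning the lone space is picked twice as often as the
# doubled-letter group as a whole.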
def char_count_output(input_data: Tuple[str, str, str], language: str, separator: str, config: Config) -> str:
    words, combined_words, char = input_data
    count = combined_words.lower().count(char.lower())
    spelled_words = [separator.join(word) for word in words]
    if random.random() < config.separator.word_separator_ratio:
        word_separator = random.choice(config.separator.word_separators)
    else:
        word_separator = ' '
    # Surround the word separator with the letter separator; when the word
    # separator is empty, fall back to a single letter separator so it is not
    # doubled in the joined output.
    separated_word_separator = f"{separator}{word_separator}{separator}" if word_separator else separator
    spelled_output = separated_word_separator.join(spelled_words)
    intro_and_responses = {
        "english": (
            "Character count analysis:",
            lambda so, c, n: f"Spelled words: {so}. The character '{c}' appears {n} time(s) in total."
        ),
        "german": (
            "Zeichenzählanalyse:",
            lambda so, c, n: f"Buchstabierte Wörter: {so}. Der Buchstabe '{c}' kommt insgesamt {n} Mal vor."
        ),
        "french": (
            "Analyse du nombre de caractères :",
            lambda so, c, n: f"Mots épelés : {so}. Le caractère '{c}' apparaît {n} fois au total."
        ),
        "italian": (
            "Analisi del conteggio dei caratteri:",
            lambda so, c, n: f"Parole compitate: {so}. Il carattere '{c}' appare {n} volta/e in totale."
        ),
        "portuguese": (
            "Análise de contagem de caracteres:",
            lambda so, c, n: f"Palavras soletradas: {so}. O caractere '{c}' aparece {n} vez(es) no total."
        ),
        "spanish": (
            "Análisis de recuento de caracteres:",
            lambda so, c, n: f"Palabras deletreadas: {so}. El carácter '{c}' aparece {n} vez/veces en total."
        )
    }
    intro, response_func = intro_and_responses.get(language, intro_and_responses["english"])
    response = response_func(spelled_output, char, count)
    return f"{intro} {response}"
def char_substitution_input(words: List[str], combined_words: str) -> Tuple[List[str], str, str]:
    # Combine all words to get the full character set
    all_chars = set(''.join(words).lower())
    # Choose old_char from the actual characters in the words
    old_char = random.choice(list(all_chars))
    # Ensure new_char is different from old_char (ignoring case)
    while True:
        new_char = random.choice(string.ascii_letters)
        if new_char.lower() != old_char.lower():
            break
    return words, old_char, new_char

def char_substitution_output(input_data: Tuple[List[str], str, str], language: str, separator: str, config: Config) -> str:
    words, old_char, new_char = input_data
    original_spelled = []
    results = []
    final_results = []
    for word in words:
        # Spell out the original word
        original_spelled.append(separator.join(word))
        result = ""
        for char in word:
            if char.lower() == old_char.lower():
                result += new_char.upper() if char.isupper() else new_char.lower()
            else:
                result += char
        results.append(separator.join(result))
        final_results.append(result)
    # Use word_separator_ratio and word_separators for joining words
    if random.random() < config.separator.word_separator_ratio:
        word_separator = random.choice(config.separator.word_separators)
    else:
        word_separator = ' '
    # As above, fall back to a single letter separator when the word separator
    # is empty, to avoid doubling it.
    separated_word_separator = f"{separator}{word_separator}{separator}" if word_separator else separator
    original = separated_word_separator.join(original_spelled)
    modified = separated_word_separator.join(results)
    final_result = word_separator.join(final_results)
    responses = {
        "english": lambda o, m, r, oc, nc: f"Spelled out: {o}\nReplacing '{oc}' with '{nc}'\nSpelled out result: {m}\nFinal result: {r}",
        "german": lambda o, m, r, oc, nc: f"Buchstabiert: {o}\nErsetze '{oc}' durch '{nc}'\nBuchstabiertes Ergebnis: {m}\nEndgültiges Ergebnis: {r}",
        "french": lambda o, m, r, oc, nc: f"Épelé : {o}\nRemplacement de '{oc}' par '{nc}'\nRésultat épelé : {m}\nRésultat final : {r}",
        "italian": lambda o, m, r, oc, nc: f"Compitato: {o}\nSostituzione di '{oc}' con '{nc}'\nRisultato compitato: {m}\nRisultato finale: {r}",
        "portuguese": lambda o, m, r, oc, nc: f"Soletrado: {o}\nSubstituindo '{oc}' por '{nc}'\nResultado soletrado: {m}\nResultado final: {r}",
        "spanish": lambda o, m, r, oc, nc: f"Deletreado: {o}\nReemplazando '{oc}' por '{nc}'\nResultado deletreado: {m}\nResultado final: {r}"
    }
    return responses.get(language, responses["english"])(original, modified, final_result, old_char, new_char)
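
# Illustrative char_substitution_output result for ["cat"], replacing 'c' with
# 'b', separator '→':
#   Spelled out: c→a→t
#   Replacing 'c' with 'b'
#   Spelled out result: b→a→t
#   Final result: bat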
def save_dataset(dataset: List[Dict], filename: str) -> None:
    try:
        with open(filename, "w", encoding="utf-8") as f:
            for item in dataset:
                json.dump(item, f, ensure_ascii=False)
                f.write('\n')
        print(f"Dataset saved to '{filename}'")
    except IOError as e:
        logging.error(f"Error saving dataset to '{filename}': {e}")
def main():
    spelling_instructions = {
        "english": [
            "Could you spell out the words `{words}` for me?",
            "How do you spell '{words}'?",
            "Please give me the spelling of '{words}'.",
            "`{words}` - can you break these down letter by letter?",
            "Give me the spelled out version of the words: {words}"
        ],
"german": [ | |
"Können Sie die Wörter `{words}` für mich buchstabieren?", | |
"Wie buchstabiert man '{words}'?", | |
"Bitte geben Sie mir die Buchstabierung von '{words}'.", | |
"`{words}` - können Sie diese Buchstabe für Buchstabe aufschlüsseln?", | |
"Nennen Sie mir die ausgeschriebene Version der Wörter: {words}" | |
], | |
"french": [ | |
"Pourriez-vous épeler les mots `{words}` pour moi ?", | |
"Comment écrit-on '{words}' ?", | |
"Veuillez me donner l'orthographe de '{words}'.", | |
"`{words}` - pouvez-vous les décomposer lettre par lettre ?", | |
"Donnez-moi la version épelée des mots : {words}" | |
], | |
"italian": [ | |
"Potresti sillabare le parole `{words}` per me?", | |
"Come si scrivono '{words}'?", | |
"Per favore, dammi lo spelling di '{words}'.", | |
"`{words}` - puoi scomporle lettera per lettera?", | |
"Dammi la versione sillabata delle parole: {words}" | |
], | |
"portuguese": [ | |
"Você poderia soletrar as palavras `{words}` para mim?", | |
"Como se soletra '{words}'?", | |
"Por favor, dê-me a soletração de '{words}'.", | |
"`{words}` - você pode decompor isso letra por letra?", | |
"Dê-me a versão soletrada das palavras: {words}" | |
], | |
"spanish": [ | |
"¿Podrías deletrear las palabras `{words}` para mí?", | |
"¿Cómo se deletrean '{words}'?", | |
"Por favor, dame el deletreo de '{words}'.", | |
"`{words}` - ¿puedes descomponerlas letra por letra?", | |
"Dame la versión deletreada de las palabras: {words}" | |
] | |
} | |
separator_instructions = { | |
"english": "Use '{separator}' as a separator between letters.", | |
"german": "Verwenden Sie '{separator}' als Trennzeichen zwischen den Buchstaben.", | |
"french": "Utilisez '{separator}' comme séparateur entre les lettres.", | |
"italian": "Usa '{separator}' come separatore tra le lettere.", | |
"portuguese": "Use '{separator}' como separador entre as letras.", | |
"spanish": "Usa '{separator}' como separador entre las letras." | |
} | |
    char_count_instructions = {
        "english": [
            "In the words: {words}, how many times does the letter: {char} appear in total?",
            "Count the occurrences of '{char}' in '{words}'. How many are there in total?",
            "'{words}' contain how many instances of the character '{char}' altogether?",
            "If we look at `{words}`, what's the total frequency of `{char}`?",
            "How often does `{char}` show up when you spell out `{words}`?"
        ],
        "german": [
            "In den Wörtern: {words}, wie oft kommt der Buchstabe: {char} insgesamt vor?",
            "Zählen Sie die Vorkommen von '{char}' in '{words}'. Wie viele sind es insgesamt?",
            "'{words}' enthalten wie viele Instanzen des Zeichens '{char}' insgesamt?",
            "Wenn wir uns `{words}` ansehen, wie häufig ist `{char}` insgesamt?",
            "Wie oft taucht `{char}` auf, wenn Sie `{words}` buchstabieren?"
        ],
        "french": [
            "Dans les mots : {words}, combien de fois apparaît la lettre : {char} au total ?",
            "Comptez les occurrences de '{char}' dans '{words}'. Combien y en a-t-il au total ?",
            "'{words}' contiennent combien d'instances du caractère '{char}' en tout ?",
            "Si on regarde `{words}`, quelle est la fréquence totale de `{char}` ?",
            "Combien de fois `{char}` apparaît-il quand vous épelez `{words}` ?"
        ],
        "italian": [
            "Nelle parole: {words}, quante volte appare la lettera: {char} in totale?",
            "Conta le occorrenze di '{char}' in '{words}'. Quante ce ne sono in totale?",
            "'{words}' contengono quante istanze del carattere '{char}' in tutto?",
            "Se guardiamo `{words}`, qual è la frequenza totale di `{char}`?",
            "Quante volte appare `{char}` quando si sillabano `{words}`?"
        ],
        "portuguese": [
            "Nas palavras: {words}, quantas vezes a letra: {char} aparece no total?",
            "Conte as ocorrências de '{char}' em '{words}'. Quantas existem no total?",
            "'{words}' contêm quantas instâncias do caractere '{char}' no total?",
            "Se olharmos para `{words}`, qual é a frequência total de `{char}`?",
            "Com que frequência `{char}` aparece quando você soletra `{words}`?"
        ],
        "spanish": [
            "En las palabras: {words}, ¿cuántas veces aparece la letra: {char} en total?",
            "Cuenta las apariciones de '{char}' en '{words}'. ¿Cuántas hay en total?",
            "'{words}' contienen ¿cuántas instancias del carácter '{char}' en total?",
            "Si miramos `{words}`, ¿cuál es la frecuencia total de `{char}`?",
            "¿Con qué frecuencia aparece `{char}` cuando deletreas `{words}`?"
        ]
    }
    char_substitution_instructions = {
        "english": [
            "In the words '{words}', replace every '{old_char}' with '{new_char}'. How do they look?",
            "Substitute '{old_char}' with '{new_char}' in '{words}'. What's the result?",
            "Change all instances of '{old_char}' to '{new_char}' in '{words}'. What do you get?",
            "If we swap '{old_char}' for '{new_char}' in '{words}', what's the outcome?",
            "Transform '{words}' by replacing '{old_char}' with '{new_char}'. What are the new words?"
        ],
        "german": [
            "Ersetzen Sie in den Wörtern '{words}' jedes '{old_char}' durch '{new_char}'. Wie sehen sie aus?",
            "Tauschen Sie '{old_char}' mit '{new_char}' in '{words}' aus. Was ist das Ergebnis?",
            "Ändern Sie alle Vorkommen von '{old_char}' zu '{new_char}' in '{words}'. Was erhalten Sie?",
            "Wenn wir '{old_char}' durch '{new_char}' in '{words}' ersetzen, was ist das Resultat?",
            "Transformieren Sie '{words}', indem Sie '{old_char}' durch '{new_char}' ersetzen. Wie lauten die neuen Wörter?"
        ],
        "french": [
            "Dans les mots '{words}', remplacez chaque '{old_char}' par '{new_char}'. À quoi ressemblent-ils ?",
            "Substituez '{old_char}' par '{new_char}' dans '{words}'. Quel est le résultat ?",
            "Changez toutes les occurrences de '{old_char}' en '{new_char}' dans '{words}'. Qu'obtenez-vous ?",
            "Si on échange '{old_char}' contre '{new_char}' dans '{words}', quel est le résultat ?",
            "Transformez '{words}' en remplaçant '{old_char}' par '{new_char}'. Quels sont les nouveaux mots ?"
        ],
        "italian": [
            "Nelle parole '{words}', sostituisci ogni '{old_char}' con '{new_char}'. Come appaiono?",
            "Sostituisci '{old_char}' con '{new_char}' in '{words}'. Qual è il risultato?",
            "Cambia tutte le istanze di '{old_char}' in '{new_char}' in '{words}'. Cosa ottieni?",
            "Se scambiamo '{old_char}' con '{new_char}' in '{words}', qual è il risultato?",
            "Trasforma '{words}' sostituendo '{old_char}' con '{new_char}'. Quali sono le nuove parole?"
        ],
        "portuguese": [
            "Nas palavras '{words}', substitua cada '{old_char}' por '{new_char}'. Como ficam?",
            "Substitua '{old_char}' por '{new_char}' em '{words}'. Qual é o resultado?",
            "Mude todas as ocorrências de '{old_char}' para '{new_char}' em '{words}'. O que você obtém?",
            "Se trocarmos '{old_char}' por '{new_char}' em '{words}', qual é o resultado?",
            "Transforme '{words}' substituindo '{old_char}' por '{new_char}'. Quais são as novas palavras?"
        ],
        "spanish": [
            "En las palabras '{words}', reemplaza cada '{old_char}' por '{new_char}'. ¿Cómo se ven?",
            "Sustituye '{old_char}' por '{new_char}' en '{words}'. ¿Cuál es el resultado?",
            "Cambia todas las instancias de '{old_char}' a '{new_char}' en '{words}'. ¿Qué obtienes?",
            "Si intercambiamos '{old_char}' por '{new_char}' en '{words}', ¿cuál es el resultado?",
            "Transforma '{words}' reemplazando '{old_char}' por '{new_char}'. ¿Cuáles son las nuevas palabras?"
        ]
    }
    # Create Task instances
    tasks = [
        Task("spelling", spelling_input, spelling_output, spelling_instructions, separator_instructions),
        Task("char_count", char_count_input, char_count_output, char_count_instructions, separator_instructions),
        Task("char_substitution", char_substitution_input, char_substitution_output, char_substitution_instructions, separator_instructions)
    ]
    # Create Config instance
    config = Config()
    # Create Language instances
    languages = [
        Language(name, load_dictionary(name), weight) for name, weight in config.languages
    ]
    # Remove languages with empty dictionaries
    languages = [lang for lang in languages if lang.dictionary]
    # Create DatasetGenerator instance
    base_generator = DatasetGenerator(languages, tasks, config)
    diverse_generator = DiverseDatasetGenerator(base_generator, batch_size=1000, similarity_threshold=0.8)
    # Generate dataset
    train_set, test_set = diverse_generator.generate_diverse_dataset()
    # Print sample results
    print("Sample training set entries:")
    for sample in train_set[:3]:
        print(json.dumps(sample, ensure_ascii=False, indent=2))
    print("\nSample testing set entries:")
    for sample in test_set[:3]:
        print(json.dumps(sample, ensure_ascii=False, indent=2))
    save_dataset(train_set, "multilingual_dataset_train.jsonl")
    save_dataset(test_set, "multilingual_dataset_test.jsonl")
    # Print dataset statistics
    train_count = len(train_set)
    test_count = len(test_set)
    total_count = train_count + test_count
    print("\nDataset statistics:")
    print(f"Total samples: {total_count}")
    print(f"Training set: {train_count} ({train_count/total_count:.2%})")
    print(f"Testing set: {test_count} ({test_count/total_count:.2%})")
    lang_stats = defaultdict(lambda: defaultdict(int))
    for dataset in (train_set, test_set):
        for sample in dataset:
            lang_stats[sample['language']][sample['task']] += 1
    print("\nSamples per language and task:")
    for lang, task_counts in lang_stats.items():
        print(f"{lang}:")
        for task, count in task_counts.items():
            print(f"  {task}: {count}")

if __name__ == "__main__":
    main()
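
# Usage sketch (assumptions: the script is saved as, e.g., generate_dataset.py,
# and word lists such as english_dictionary.txt sit next to it; both names are
# illustrative):
#   $ python generate_dataset.py
# This writes multilingual_dataset_train.jsonl and multilingual_dataset_test.jsonl,
# one chat-format JSON sample per line, after the embedding-based diversity filter.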