Created October 14, 2024 15:08
The first script generates a synthetic search query for every text file in a directory via the OpenAI API and writes the result as a BEIR-compatible dataset (corpus.jsonl, queries.jsonl, and qrels/train.tsv / qrels/test.tsv).
import os
from openai import OpenAI
import json
import re
import hashlib
import random
from tqdm import tqdm
from dotenv import load_dotenv

load_dotenv()
oai_client = OpenAI()
# Directory containing the TXT files
input_dir = r"<path to the text files>"

# Load the list of TXT files from the directory
txt_files = [f for f in os.listdir(input_dir) if f.endswith('.txt')]

# Shuffle for a random split
random.shuffle(txt_files)

# Split into train (80%) and test (20%)
split_index = int(0.8 * len(txt_files))
train_files = txt_files[:split_index]
test_files = txt_files[split_index:]

# Helper to generate short, stable IDs from a string
def generate_hash(value):
    hash_object = hashlib.sha256(value.encode('utf-8'))
    return hash_object.hexdigest()[:16]
# Generate a query for each file and collect the results
def process_files(file_list, split):
    split_queries = []
    for txt_file in tqdm(file_list, desc=f"Generating {split} queries"):
        # Extract the title (file name without the "#<number>" suffix)
        title = re.sub(r"#\d+", "", os.path.splitext(txt_file)[0]).strip()

        # Read the file content
        try:
            with open(os.path.join(input_dir, txt_file), 'r', encoding='utf-8') as file:
                text = file.read()
        except UnicodeDecodeError:
            with open(os.path.join(input_dir, txt_file), 'r', encoding='latin-1') as file:
                text = file.read()

        # Prompt for generating the query (kept in German to match the German source texts)
        prompt = f"Erstelle eine passende Suchanfrage (Query) basierend auf folgendem Text:\nTitle: {title}\nText: {text}\n\nQuery:"

        # Call the OpenAI API to generate the query
        response = oai_client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            model="gpt-4o-mini",
            max_tokens=100,  # limit the length of the query
            temperature=0.7
        )
        generated_query = response.choices[0].message.content.replace('"', '').strip()
        print('\n', generated_query)

        # Generate hashes for the IDs
        corpus_id = generate_hash(title)
        query_id = generate_hash(title + generated_query)

        # Add the result to the list
        split_queries.append({
            "query_id": f"q_{query_id}",
            "query": generated_query,
            "corpus_id": corpus_id,
            "score": 1  # the generated query is treated as fully relevant to its source document
        })
    return split_queries
# Generate the train and test queries
train_queries = process_files(train_files, "train")
test_queries = process_files(test_files, "test")

# Save the dataset in a BEIR-compatible format
output_dir = "beir_custom_dataset"
os.makedirs(output_dir, exist_ok=True)

# Save the corpus (corpus.jsonl)
with open(os.path.join(output_dir, "corpus.jsonl"), "w", encoding='utf-8') as corpus_file:
    for txt_file in txt_files:
        try:
            with open(os.path.join(input_dir, txt_file), 'r', encoding='utf-8') as file:
                text = file.read()
        except UnicodeDecodeError:
            with open(os.path.join(input_dir, txt_file), 'r', encoding='latin-1') as file:
                text = file.read()
        title = re.sub(r"#\d+", "", os.path.splitext(txt_file)[0]).strip()
        corpus_id = generate_hash(title)
        corpus_entry = {
            "_id": corpus_id,
            "title": title,
            "text": text
        }
        corpus_file.write(json.dumps(corpus_entry) + "\n")

# Save the queries (queries.jsonl)
with open(os.path.join(output_dir, "queries.jsonl"), "w", encoding='utf-8') as queries_file:
    for query in train_queries + test_queries:
        queries_file.write(json.dumps({"_id": query["query_id"], "text": query["query"]}) + "\n")

# Save the relevance judgments (qrels/train.tsv and qrels/test.tsv)
os.makedirs(os.path.join(output_dir, "qrels"), exist_ok=True)
with open(os.path.join(output_dir, "qrels/train.tsv"), "w", encoding='utf-8') as qrels_file:
    qrels_file.write("query-id\tcorpus-id\tscore\n")
    for query in train_queries:
        qrels_file.write(f"{query['query_id']}\t{query['corpus_id']}\t{query['score']}\n")
with open(os.path.join(output_dir, "qrels/test.tsv"), "w", encoding='utf-8') as qrels_file:
    qrels_file.write("query-id\tcorpus-id\tscore\n")
    for query in test_queries:
        qrels_file.write(f"{query['query_id']}\t{query['corpus_id']}\t{query['score']}\n")

print("BEIR-compatible dataset has been created.")
The second script is a quick sanity check of the base embedding model: it encodes a German query together with a handful of German passages and prints their cosine similarities.
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

# 1. Specify preferred dimensions
dimensions = 1024

# 2. Load model
model = SentenceTransformer("mixedbread-ai/deepset-mxbai-embed-de-large-v1",
                            truncate_dim=dimensions,
                            device="cuda")

query = "Was gab es heute zum Mittagessen?"
docs = [
    query,
    "Embeddingmodelle spielen eine Schlüsselrolle in der natürlichen Sprachverarbeitung. Sie wandeln Wörter und Sätze in Vektoren um, die semantische Beziehungen zwischen Wörtern ausdrücken.",
    "Ein bekanntes Embeddingmodell ist Word2Vec, das erstmals die Idee popularisierte, Wörter als dichte Vektoren in einem mehrdimensionalen Raum darzustellen.",
    "Embeddingmodelle haben jedoch einige Schwächen. Zum Beispiel berücksichtigen statische Modelle wie Word2Vec nicht den Kontext von Wörtern in einem Satz.",
    "Bei der Entwicklung von Embeddingmodellen ist es wichtig, große Datenmengen zu verwenden, um hochwertige Repräsentationen zu erzeugen, die verschiedene Bedeutungen eines Wortes korrekt erfassen können.",
    "Unsere Bäckerei bietet eine Auswahl an frisch gebackenem Brot, darunter auch glutenfreie und biologische Optionen, die sich perfekt für Menschen mit besonderen Ernährungsbedürfnissen eignen.",
    "Der Fortschritt in der Embedding-Technologie hat die Entwicklung von Modellen wie BERT ermöglicht, die den Kontext eines Wortes dynamisch in einem Satz berücksichtigen.",
    "Kuchen und Gebäck gehören ebenfalls zu unserem Angebot, und wir legen großen Wert auf die Verwendung von natürlichen Zutaten, um höchste Qualität und Geschmack zu garantieren.",
    "Transformer-Modelle wie BERT und GPT setzen neue Maßstäbe in der natürlichen Sprachverarbeitung, da sie es ermöglichen, kontextabhängige Bedeutungen besser zu erfassen als frühere Modelle."
]

# 3. Encode
embeddings = model.encode(docs)
similarities = cos_sim(embeddings[0], embeddings[1:])
print('similarities:', similarities)
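As a small follow-up, the passages can be ranked by their similarity to the query. This is a minimal sketch, assuming the snippet above has already run; cos_sim returns a tensor of shape (1, len(docs) - 1), and the indices are shifted by one because the query itself is the first entry in docs.

import torch

# Rank the candidate passages by cosine similarity to the query
scores = similarities[0]            # shape: (len(docs) - 1,)
top = torch.topk(scores, k=3)
for score, idx in zip(top.values, top.indices):
    print(f"{score:.3f}  {docs[int(idx) + 1][:60]}...")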
The third script evaluates an embedding model on the generated dataset with BEIR, reporting NDCG, MAP, recall, and precision as well as runtime and memory usage.
from beir import LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
import logging
import pathlib
import os
import time
import psutil
import torch

#### Print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### Path to your own dataset (instead of scifact)
data_folder = os.path.join(pathlib.Path(__file__).parent.absolute(), "../beir_custom_dataset")

#### Load your own dataset
corpus, queries, qrels = GenericDataLoader(data_folder=data_folder).load(split="test")
formatted_queries = {k: f"query: {v}" for k, v in queries.items()}
formatted_corpus = {k: {**v, 'text': f"passage: {v['text']}"} for k, v in corpus.items()}

# Check whether CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"

#### Start timing
start_time = time.time()

#### Clear the GPU memory before loading the model
if device == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

#### Measure CPU memory usage
process = psutil.Process(os.getpid())
start_memory = process.memory_info().rss  # CPU memory usage at the start (in bytes)

#### Load the Sentence-BERT model and use cosine similarity
#model = DRES(models.SentenceBERT("sentence-transformers/all-mpnet-base-v2", device=device), batch_size=16)
model = DRES(models.SentenceBERT("mixedbread-ai/deepset-mxbai-embed-de-large-v1", device=device), batch_size=16)
#model = DRES(models.SentenceBERT("intfloat/multilingual-e5-large", device=device), batch_size=16)
#model = DRES(models.SentenceBERT("deepset/gbert-large", device=device), batch_size=16)
#model = DRES(models.SentenceBERT("google-bert/bert-base-multilingual-uncased", device=device), batch_size=16)
# The fine-tuned model (see the training script below)
#model = DRES(models.SentenceBERT("output/trained-mixedbread-model", device=device), batch_size=16)

retriever = EvaluateRetrieval(model, score_function="cos_sim")  # use "cos_sim" for cosine similarity
results = retriever.retrieve(formatted_corpus, formatted_queries, device=device)
#### Evaluate the model with NDCG@k, MAP@k, Recall@k and Precision@k, where k = [1,3,5,10,100,1000]
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)

#### Stop timing
end_time = time.time()
runtime = end_time - start_time

#### Measure memory usage (CPU)
end_memory = process.memory_info().rss
cpu_memory_usage = (end_memory - start_memory) / (1024 ** 2)  # CPU memory usage in MB

#### Measure GPU memory usage
if device == "cuda":
    gpu_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 2)  # peak GPU memory usage in MB
else:
    gpu_memory_usage = 0

#### Print the results
print(f"NDCG@10: {ndcg['NDCG@10']:.4f}")
print(f"MAP@10: {_map['MAP@10']:.4f}")
print(f"Recall@10: {recall['Recall@10']:.4f}")
print(f"Precision@10: {precision['P@10']:.4f}")

#### Print runtime, CPU and GPU memory usage
print(f"Runtime of the test: {runtime:.2f} seconds")
print(f"CPU memory usage of the embedding model: {cpu_memory_usage:.2f} MB")
print(f"GPU memory usage of the embedding model: {gpu_memory_usage:.2f} MB")
The fourth script fine-tunes the embedding model on the train split of the generated dataset using MultipleNegativesRankingLoss and saves the resulting checkpoint.
from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, losses
from torch.utils.data import DataLoader
import json
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {device}")

# Load the pretrained model
model = SentenceTransformer('mixedbread-ai/deepset-mxbai-embed-de-large-v1', device=device)
model_save_path = "output/trained-mixedbread-model"
# Load the queries as a dictionary
queries_dict = {}
with open("beir_custom_dataset/queries.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        query = json.loads(line)
        queries_dict[query['_id']] = query['text']

# Load the corpus
corpus = {}
with open("beir_custom_dataset/corpus.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        doc = json.loads(line)
        corpus[doc['_id']] = doc['text']

formatted_queries = {k: f"query: {v}" for k, v in queries_dict.items()}
formatted_corpus = {k: f"passage: {v}" for k, v in corpus.items()}

# Prepare the training data (as InputExample objects)
train_examples = []
with open("beir_custom_dataset/qrels/train.tsv", "r", encoding="utf-8") as f:
    next(f)  # skip the header row
    for line in f:
        query_id, corpus_id, score = line.strip().split('\t')
        train_examples.append(InputExample(texts=[formatted_queries[query_id], formatted_corpus[corpus_id]], label=float(score)))

# Build a dataset for Sentence-Transformers
train_dataset = SentencesDataset(train_examples, model=model)

# Prepare the DataLoader
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=16)

# Choose the loss function (MultipleNegativesRankingLoss treats the other
# passages in a batch as negatives and ignores the labels)
train_loss = losses.MultipleNegativesRankingLoss(model)

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100)

print("Training finished")
model.save(model_save_path)
print(f"Model was saved to {model_save_path}.")