Improved and modularized code for text classification with an aggression detection model
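The script trains a binary sequence classifier (or loads a saved one from --model_path), classifies a single --text_input string, and prints a test-set classification report. A typical invocation, assuming the file is saved as aggression_detection.py (file name hypothetical) and that train.csv/test.csv exist:

python aggression_detection.py --train_data ./train.csv --test_data ./test.csv --text_input "how do you do"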
import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          get_linear_schedule_with_warmup)
from tqdm import trange

# Placeholder for speech-to-text library (e.g., SpeechRecognition)
# import speech_recognition as sr
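
# Assumed third-party dependencies (versions are not pinned in this gist):
#   pip install torch transformers numpy pandas scikit-learn tqdm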

class AggressionDataset(Dataset):
    """Wraps raw texts and labels and tokenizes them on the fly."""

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',  # pad_to_max_length is deprecated
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
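
# Illustrative shape of one dataset item (assuming max_len=512):
#   {'text': str, 'input_ids': LongTensor[512],
#    'attention_mask': LongTensor[512], 'labels': scalar LongTensor}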

def load_data(filepath):
    """Loads a CSV file with 'text' and 'label' columns."""
    df = pd.read_csv(filepath)
    texts = df['text'].tolist()
    labels = df['label'].tolist()
    return texts, labels
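
# Expected CSV layout for load_data (illustrative rows; the 0 = aggressive,
# 1 = normal encoding is an assumption that must match the `labels` list in
# predict_aggression below):
#   text,label
#   "get lost, you loser",0
#   "have a great day",1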

def prepare_data(texts, labels, tokenizer, max_len, batch_size, shuffle=False):
    dataset = AggressionDataset(texts, labels, tokenizer, max_len)
    # Shuffling is enabled for training and left off for evaluation
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def train(model, data_loader, optimizer, scheduler, device, num_epochs):
    model = model.to(device)
    for epoch in trange(num_epochs, desc="Epoch"):
        model.train()
        total_loss = 0
        for batch in data_loader:
            model.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()
            loss.backward()
            # Clip gradients to stabilize training
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        avg_train_loss = total_loss / len(data_loader)
        print(f"Average training loss: {avg_train_loss:.2f}")

def evaluate(model, data_loader, device):
    model = model.to(device)
    model.eval()
    predictions, true_labels = [], []
    for batch in data_loader:
        # The DataLoader yields dicts (see AggressionDataset), not tuples
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits.detach().cpu().numpy()
        label_ids = labels.numpy()
        predictions.append(logits)
        true_labels.append(label_ids)
    predictions = np.concatenate(predictions, axis=0)
    true_labels = np.concatenate(true_labels, axis=0)
    return predictions, true_labels
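
# Sketch of how evaluate() is typically consumed (test_loader is hypothetical;
# main() below does the equivalent):
#   preds, golds = evaluate(model, test_loader, device)
#   pred_classes = np.argmax(preds, axis=1)  # raw logits -> class indices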

def predict_aggression(text, model, tokenizer, device, max_length=512):
    """
    Predicts the aggression level of a given text.

    Args:
        text: The input text (string).
        model: The pre-trained aggression detection model.
        tokenizer: The tokenizer corresponding to the model.
        device: The device (CPU or CUDA) to run the model on.
        max_length: The maximum sequence length for tokenization.

    Returns:
        A dictionary containing the predicted label and probability.
    """
    try:
        encoded_input = tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        model.eval()
        with torch.no_grad():
            outputs = model(**encoded_input)
        logits = outputs.logits
        probabilities = F.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        predicted_probability = probabilities[0][predicted_class].item()
        # Index-to-label mapping; must match the label encoding used in training
        labels = ['aggressive', 'normal']
        predicted_label = labels[predicted_class]
        return {
            "predicted_label": predicted_label,
            "predicted_probability": predicted_probability
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        return {
            "predicted_label": "Error",
            "predicted_probability": 0.0
        }
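
# Minimal usage sketch for predict_aggression (values shown are illustrative):
#   device = torch.device("cpu")
#   out = predict_aggression("how do you do", model, tokenizer, device)
#   # -> e.g. {'predicted_label': 'normal', 'predicted_probability': 0.98}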

def main():
    parser = argparse.ArgumentParser(description="Aggression Detection with Transformers")
    parser.add_argument("--model_path", type=str, default="./output/Modernbert2",
                        help="Path to the pre-trained model directory")
    parser.add_argument("--train_data", type=str, default="./train.csv",
                        help="Path to the training data file")
    parser.add_argument("--test_data", type=str, default="./test.csv",
                        help="Path to the test data file")
    parser.add_argument("--num_epochs", type=int, default=5,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Batch size for training and evaluation")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate for the optimizer")
    parser.add_argument("--max_len", type=int, default=512,
                        help="Maximum sequence length")
    parser.add_argument("--text_input", type=str, default="how do you do",
                        help="Text input for aggression detection")
    # parser.add_argument("--audio_input", type=str,
    #                     help="Path to audio file for speech-to-text and aggression detection")
    args = parser.parse_args()

    # Use CUDA if available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load a saved model if one exists; otherwise fine-tune from scratch
    if os.path.exists(args.model_path):
        print("Loading the pre-trained model...")
        model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        model = model.to(device)  # move the loaded model to the same device as the inputs
    else:
        print("Training the model...")
        tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
        train_texts, train_labels = load_data(args.train_data)
        train_data_loader = prepare_data(train_texts, train_labels, tokenizer,
                                         args.max_len, args.batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
        num_training_steps = args.num_epochs * len(train_data_loader)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                                    num_training_steps=num_training_steps)
        train(model, train_data_loader, optimizer, scheduler, device, args.num_epochs)
        # Save the fine-tuned model and tokenizer for later runs
        model.save_pretrained(args.model_path)
        tokenizer.save_pretrained(args.model_path)

    # If audio input is provided, perform speech-to-text
    # if args.audio_input:
    #     try:
    #         r = sr.Recognizer()
    #         with sr.AudioFile(args.audio_input) as source:
    #             audio_data = r.record(source)
    #         text = r.recognize_google(audio_data)
    #         print(f"Transcribed text: {text}")
    #     except Exception as e:
    #         print(f"Error during speech-to-text: {e}")
    #         return

    # Perform aggression detection on the text input
    if args.text_input:
        result = predict_aggression(args.text_input, model, tokenizer, device, args.max_len)
        print(f"Predicted label: {result['predicted_label']}")
        print(f"Predicted probability: {result['predicted_probability']:.4f}")

    # Evaluate the model on the test set
    test_texts, test_labels = load_data(args.test_data)
    test_data_loader = prepare_data(test_texts, test_labels, tokenizer, args.max_len, args.batch_size)
    predictions, true_labels = evaluate(model, test_data_loader, device)
    print(classification_report(true_labels, np.argmax(predictions, axis=1)))

if __name__ == "__main__":
    main()