Prototype tuning model
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from torch.nn import functional as F
from torch import nn
import torch
from typing import Optional
from sentence_transformers import SentenceTransformer
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import Dataset
from transformers import Trainer, TrainingArguments
from sklearn.metrics import f1_score

class PrototypeTuningModel(nn.Module):
    def __init__(
        self,
        prototype_embeddings: torch.FloatTensor,
        dropout: float,
        drift_coefficient: float,
    ):
        super().__init__()
        self.prototype_embeddings = nn.Parameter(
            data=prototype_embeddings, requires_grad=True
        )
        self.dropout = dropout
        self.drift_coefficient = drift_coefficient
        # Keep a frozen copy of the initial prototypes as a buffer so it moves
        # together with the model (e.g. to CUDA) and can be used in the drift loss.
        self.register_buffer(
            "initial_prototype_embeddings", prototype_embeddings.detach().clone()
        )
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        model = super().to(*args, **kwargs)
        self.device = next(model.parameters()).device
        return model

    def get_trainable_params(self) -> int:
        num_trainable_params = 0
        for param in self.parameters():
            if param.requires_grad:
                num_trainable_params += param.numel()
        return num_trainable_params
    def forward(
        self,
        sentence_embeddings: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
    ) -> SequenceClassifierOutput:
        """
        Forward method of PrototypeTuningModel. Computes the logits
        for each input sentence by means of the dot product between
        the sentence and the prototype embeddings.

        Args:
            sentence_embeddings (FloatTensor): input sentence embeddings, of shape (batch_size, dim)
            labels (Optional[LongTensor]): labels of each input sentence, of shape (batch_size, num_questions)

        Returns:
            SequenceClassifierOutput: output containing the logits, and the loss if labels are passed as argument.
        """
        batch_size = sentence_embeddings.shape[0]
        num_questions, num_prototypes, _ = self.prototype_embeddings.shape

        # Apply dropout to the prototype embeddings to prevent overfitting
        prototype_embeddings = F.dropout(
            self.prototype_embeddings, p=self.dropout, training=self.training
        )

        # Dot product between the question prototypes and the sentence embeddings
        # (batch_size, num_questions, num_prototypes)
        logits = torch.einsum(
            "qpd,bd -> bqp", prototype_embeddings, sentence_embeddings
        )

        # Reshape the logits to separate the non-relevant and relevant prototypes
        # (batch_size, num_questions, 2, num_prototypes / 2)
        logits = logits.view(batch_size, num_questions, 2, num_prototypes // 2)

        # Average the logits within each group of prototypes
        # (batch_size, num_questions, 2)
        logits = logits.mean(-1)

        # Compute the loss if labels are provided. The loss is defined as:
        #   L = cross_entropy(logits, labels)
        #       + drift_coefficient * norm(prototype_embeddings - initial_prototype_embeddings)
        # The first term is the classical cross-entropy, while the second one is a drift loss
        # that prevents the embeddings from drifting too far from their initial values
        # (overfitting). You can control the overfitting by modifying the `drift_coefficient` term.
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            cross_entropy_loss = loss_fn(logits.view(-1, 2), labels.view(-1))
            drift_loss = (
                self.prototype_embeddings - self.initial_prototype_embeddings
            ).norm(p=2)
            loss = cross_entropy_loss + self.drift_coefficient * drift_loss
        return SequenceClassifierOutput(logits=logits, loss=loss)
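

# Quick shape sanity check (illustrative sketch, not part of the original gist):
# with 2 questions, 10 prototypes each (5 non-relevant + 5 relevant) and random
# 768-dimensional embeddings, the logits should have shape (batch_size, 2, 2).
_dummy_model = PrototypeTuningModel(
    prototype_embeddings=torch.randn(2, 10, 768),
    dropout=0.1,
    drift_coefficient=0.01,
)
_dummy_output = _dummy_model(
    sentence_embeddings=torch.randn(4, 768),
    labels=torch.zeros(4, 2, dtype=torch.long),
)
assert _dummy_output.logits.shape == (4, 2, 2)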


def prepare_dataset(
    data: dict, embedding_model: SentenceTransformer, batch_size: int = 16
) -> Dataset:
    """
    Prepare a dataset by encoding the text data into sentence embeddings using a SentenceTransformer model.

    Args:
        data (dict): A dictionary containing the text data and corresponding labels.
            - 'text': A list of strings representing the input sentences.
            - 'labels': A list of labels corresponding to the input sentences.
        embedding_model (SentenceTransformer): A pre-trained SentenceTransformer model used to encode the text data.
        batch_size (int, optional): The batch size to use when encoding the text data. Defaults to 16.

    Returns:
        Dataset: A PyTorch-compatible dataset containing the sentence embeddings and labels.
    """
    dataset = Dataset.from_dict(data)
    dataset = dataset.map(
        lambda texts: {
            "sentence_embeddings": embedding_model.encode(
                texts, convert_to_tensor=True
            )
        },
        input_columns=["text"],
        batched=True,
        batch_size=batch_size,
        remove_columns=["text"],
    )
    dataset.set_format(type="torch", columns=["sentence_embeddings", "labels"])
    return dataset


def get_prototype_embeddings(
    prototypes: list[list[str]], embedding_model: SentenceTransformer
) -> torch.FloatTensor:
    """
    Generate prototype embeddings for a given set of prototypes using a SentenceTransformer model.

    Args:
        prototypes (list[list[str]]): A list of lists, where each inner list contains strings representing
            the prototypes for a specific query or question.
        embedding_model (SentenceTransformer): A pre-trained SentenceTransformer model used to encode the prototypes.

    Returns:
        torch.FloatTensor: A tensor of shape (num_queries, num_prototypes_per_query, embedding_dim) containing
            the embeddings for all prototypes.
    """
    prototype_embeddings = []
    for i in range(len(prototypes)):
        prototype_embeddings.append(
            embedding_model.encode(prototypes[i], convert_to_tensor=True)
        )
    return torch.stack(prototype_embeddings, dim=0)


# Let's use synthetic dummy data for the example.
# For each question, the first 5 prototypes are non-relevant and the last 5 are relevant.
prototypes = [
    # Prototypes of the first question ("Is this person sad?")
    [
        "I love playing video games with my friends.",
        "Tomorrow is going to be an exciting day!",
        "I need to buy groceries after work.",
        "The weather is really nice today.",
        "Learning a new language is challenging but fun.",
        "I feel so empty inside, nothing makes me happy anymore.",
        "I cried myself to sleep last night.",
        "I don't think I'll ever be happy again.",
        "Everything feels so meaningless lately.",
        "I just lost someone important to me, and it hurts so much.",
    ],
    # Prototypes of the second question ("Is this person a pessimist?")
    [
        "I enjoy cooking new recipes on the weekend.",
        "I'm so excited for my vacation next month!",
        "Exercising regularly makes me feel great.",
        "I had a great time at the party last night.",
        "Spending time with family is always fun.",
        "No matter how hard I try, I always fail.",
        "Things never work out for me, so why bother?",
        "There's no point in hoping for the best—it never happens.",
        "The world is a terrible place, and it’s only getting worse.",
        "I expect things to go wrong because they always do.",
    ],
]

train_data = {
    "text": [
        "I am a good football player",
        "I'm sad because I will not find any job",
        "That's a good idea",
        "I feel so lonely and depressed",
        "Nothing ever works out for me",
        "Life is full of disappointments",
        "I love spending time with my friends",
        "Why does everything always go wrong?",
        "I just lost my best friend and I feel empty",
        "I am optimistic about the future",
        "There's no hope left for me",
        "I will never be successful in life",
        "I enjoy painting, it makes me happy",
        "I failed my exam, I feel terrible",
        "Nothing matters anymore",
        "I had a tough day, but tomorrow will be better",
        "I always expect the worst in any situation",
        "I am confident that things will improve",
        "It's pointless to try because I always fail",
        "I just want to disappear",
        "I feel like crying all the time",
        "There’s no point in making plans, they always fail",
        "I can’t stop smiling today",
    ],
    # One label per question for each sentence: [sad?, pessimist?], where 1 = relevant.
    "labels": [
        [0, 0],
        [1, 1],
        [0, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [0, 0],
        [0, 1],
        [1, 0],
        [0, 0],
        [1, 1],
        [0, 1],
        [0, 0],
        [1, 0],
        [1, 1],
        [0, 0],
        [0, 1],
        [0, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 0],
    ],
}

test_data = {
    "text": [
        "I enjoy going for long walks in the park.",
        "No one understands my pain, I feel so alone.",
        "I always try to find the bright side in every situation.",
        "Why does everything bad happen to me?",
        "I am looking forward to my weekend getaway!",
        "I don't see the point in trying anymore.",
        "I feel so lost and broken inside.",
        "Hanging out with my friends makes me feel so much better.",
        "Life is unfair, and things never go my way.",
        "I love reading books about history.",
        "I am devastated after failing that important test.",
        "No matter how much effort I put in, I always end up failing.",
        "There’s no use in expecting good things to happen.",
        "I can’t stop crying, everything is falling apart.",
        "I had an amazing time at the concert last night.",
        "Things are tough, but I believe they will improve soon.",
        "I always assume the worst outcome in any situation.",
        "There's no point in making friends, they always leave in the end.",
        "I’m feeling great after accomplishing my goals!",
        "Everything feels meaningless, I just want to give up.",
    ],
    "labels": [
        [0, 0],
        [1, 0],
        [0, 0],
        [0, 1],
        [0, 0],
        [0, 1],
        [1, 0],
        [0, 0],
        [0, 1],
        [0, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [1, 0],
        [0, 0],
        [0, 0],
        [0, 1],
        [0, 1],
        [0, 0],
        [1, 0],
    ],
}

model_name_or_path = "multi-qa-mpnet-base-dot-v1"

# Prepare the embedding model, the datasets, and the prototype embeddings
embedding_model = SentenceTransformer(model_name_or_path)
train_dataset = prepare_dataset(train_data, embedding_model)
test_dataset = prepare_dataset(test_data, embedding_model)
prototype_embeddings = get_prototype_embeddings(prototypes, embedding_model)

# Instantiate the model
model = PrototypeTuningModel(
    prototype_embeddings, dropout=0.1, drift_coefficient=0.01
)
model.to("cuda")

# Instantiate the trainer
training_args = TrainingArguments(
    do_train=True,
    output_dir="./",
    save_strategy="no",
    num_train_epochs=20,
    per_device_train_batch_size=8,
    learning_rate=1e-3,
    logging_steps=5,
    report_to="none",
)
trainer = Trainer(model=model, train_dataset=train_dataset, args=training_args)

# Let's check the zero-shot performance first
print("* Zero-shot performance *")
num_queries = prototype_embeddings.shape[0]
logits = trainer.predict(test_dataset).predictions
preds = logits.argmax(-1)
for i in range(num_queries):
    print(
        f"Macro-F1 in query {i}:",
        f1_score(
            y_true=test_dataset["labels"][:, i],
            y_pred=preds[:, i],
            average="macro",
        ),
    )

# Now fine-tune and check the improvement
print("* Performance after fine-tuning *")
trainer.train()
num_queries = prototype_embeddings.shape[0]
logits = trainer.predict(test_dataset).predictions
preds = logits.argmax(-1)
for i in range(num_queries):
    print(
        f"Macro-F1 in query {i}:",
        f1_score(
            y_true=test_dataset["labels"][:, i],
            y_pred=preds[:, i],
            average="macro",
        ),
    )
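
# Example inference on a new sentence (illustrative sketch; the input text is made up).
# It reuses the embedding model and the tuned prototypes to score the two questions.
model.eval()
with torch.no_grad():
    new_embedding = embedding_model.encode(
        ["I feel hopeless about everything."], convert_to_tensor=True
    ).to(model.device)
    new_logits = model(sentence_embeddings=new_embedding).logits
    # Probability assigned to the "relevant" class of each question (sad, pessimist)
    new_probs = new_logits.softmax(-1)[..., 1]
    print("Predicted probabilities per question:", new_probs.cpu().tolist())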