# NV-Embed-v2 Medical Text Embedding and Visualization
# This script demonstrates how to use NVIDIA's NV-Embed-v2 model to embed medical
# texts and how to visualize the results with t-SNE.

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib.patches import Ellipse
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from torch.nn import DataParallel
from tqdm import tqdm
from transformers import AutoModel
# Initialize the model with multi-GPU support if available, otherwise use a single device
def setup_model(model_name="nvidia/NV-Embed-v2", use_multi_gpu=True, cache_dir="~/cache"):
    """
    Set up the NV-Embed model with an appropriate GPU configuration.

    Args:
        model_name: HuggingFace model identifier
        use_multi_gpu: Whether to use multiple GPUs if available
        cache_dir: Directory to cache model files

    Returns:
        Tuple of (loaded model ready for inference, device string)
    """
    try:
        # Check GPU availability
        if not torch.cuda.is_available():
            print("WARNING: No GPU detected. Running on CPU will be very slow.")
            device = "cpu"
            use_multi_gpu = False
        else:
            device = "cuda"

        # Load the model (expand "~" explicitly; from_pretrained does not do this for us)
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            cache_dir=os.path.expanduser(cache_dir),
        )

        # Configure for multi-GPU if requested and available
        if use_multi_gpu and torch.cuda.device_count() > 1:
            device_count = torch.cuda.device_count()
            device_ids = list(range(device_count))
            print(f"Using {device_count} GPUs: {device_ids}")
            # Wrap each submodule in DataParallel so encode() still works on the top-level model
            for module_key, module in model._modules.items():
                model._modules[module_key] = DataParallel(module, device_ids=device_ids)
            model.to(f"cuda:{device_ids[0]}")  # Move to the primary GPU
        else:
            # Single GPU or CPU
            print(f"Using single device: {device}")
            model.to(device)

        return model, device
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
def generate_embeddings(
    df,
    text_column,
    model,
    device,
    batch_size=16,
    instruction="",
    max_length=32768,
    output_file=None,
):
    """
    Generate embeddings for the texts in a dataframe.

    Args:
        df: Pandas DataFrame containing the text data
        text_column: Column name containing the text to embed
        model: The embedding model
        device: Device to run inference on (cuda/cpu)
        batch_size: Number of texts to process at once
        instruction: Optional instruction prefix for medical context
        max_length: Maximum sequence length to process
        output_file: Optional path to save the embeddings (.npy)

    Returns:
        NumPy array of embeddings, one row per text
    """
    embeddings = []

    # Medical instruction examples:
    # instruction = "Represent this clinical note for finding similar patient cases:"
    # instruction = "Encode this medical text for disease classification:"

    # Process in batches
    for i in tqdm(range(0, len(df), batch_size)):
        try:
            # Get the current batch of texts
            batch_texts = df[text_column].iloc[i:i + batch_size].tolist()

            # Generate embeddings
            with torch.no_grad():  # Disable gradient tracking for inference
                batch_embeddings = model.encode(
                    batch_texts,
                    instruction=instruction,
                    max_length=max_length,
                    show_progress_bar=False,
                )

            # Normalize embeddings (L2 norm) so dot products equal cosine similarities
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            embeddings.append(batch_embeddings)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e) and batch_size > 1:
                print("CUDA OOM error. Consider reducing batch_size or max_length.")
                # An auto-retry with a smaller batch could be added here
                # (see the generate_embeddings_with_retry sketch below)
                raise
            else:
                print(f"Error processing batch {i}-{i + batch_size}: {e}")
                raise

    # Concatenate all batches
    if device == "cuda":
        all_embeddings = torch.cat(embeddings, dim=0).cpu().numpy()
    else:
        all_embeddings = torch.cat(embeddings, dim=0).numpy()

    # Save embeddings if an output file was specified
    if output_file:
        np.save(output_file, all_embeddings)
        print(f"Saved embeddings to {output_file}")

    print(f"Shape of embeddings: {all_embeddings.shape}")
    return all_embeddings
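
# Hypothetical sketch (not part of the original gist) of the auto-retry idea noted
# in the OOM handler above: on a CUDA out-of-memory error, halve the batch size and
# call generate_embeddings again. It takes the same arguments as generate_embeddings.
def generate_embeddings_with_retry(df, text_column, model, device, batch_size=16, **kwargs):
    while True:
        try:
            return generate_embeddings(
                df, text_column, model, device, batch_size=batch_size, **kwargs
            )
        except RuntimeError as e:
            if "CUDA out of memory" in str(e) and batch_size > 1:
                batch_size //= 2
                print(f"CUDA OOM, retrying with batch_size={batch_size}")
                torch.cuda.empty_cache()  # release cached blocks before retrying
            else:
                raise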
def test_embedding_similarity(embeddings, texts, indices=None, top_k=5):
    """
    Test similarity between embeddings to verify their quality.

    Args:
        embeddings: NumPy array of embeddings
        texts: List of original texts corresponding to the embeddings
        indices: Optional list of specific indices to compare (default: random selection)
        top_k: Number of most similar results to return

    Returns:
        None (prints results)
    """
    if indices is None:
        # Select a random index if none is provided
        query_idx = np.random.randint(0, len(embeddings))
        indices = [query_idx]

    for query_idx in indices:
        query_embedding = embeddings[query_idx:query_idx + 1]

        # Compute cosine similarity (embeddings are L2-normalized, so a dot product suffices)
        similarities = np.dot(embeddings, query_embedding.T).flatten()

        # Get the top-k most similar texts (excluding the query itself)
        most_similar = np.argsort(similarities)[::-1][1:top_k + 1]

        print(f"\nQuery text [{query_idx}]:\n{texts[query_idx][:300]}...\n")
        print("Most similar texts:")
        for i, idx in enumerate(most_similar):
            print(f"{i + 1}. Similarity: {similarities[idx]:.4f}")
            print(f"   Text [{idx}]: {texts[idx][:200]}...\n")
def visualize_embeddings(
    embeddings,
    labels,
    class_names,
    title,
    output_path=None,
    colors=None,
):
    """Create a t-SNE visualization of the embeddings.

    Args:
        embeddings: Array of embeddings to visualize.
        labels: Array of class labels (integers).
        class_names: List of class names corresponding to the labels.
        title: Title for the plot.
        output_path: Optional path to save the visualization.
        colors: Optional list of colors for the classes.
    """
    print(f"Creating t-SNE visualization for {title}...")

    # Set default colors if none are provided
    if colors is None:
        colors = ["#E64B35", "#4DBBD5", "#00A087", "#3C5488"]

    n_samples = len(embeddings)

    tsne = TSNE(
        n_components=2,
        early_exaggeration=12,
        learning_rate="auto",
        init="pca",
        random_state=42,
        max_iter=2500,
        min_grad_norm=1e-7,
        metric="cosine",
        n_iter_without_progress=300,
    )

    # Standardize features before t-SNE
    scaler = StandardScaler()
    embeddings_scaled = scaler.fit_transform(embeddings)

    # Try multiple perplexity values (scaled to the dataset size) and keep the
    # projection with the lowest KL divergence for a more robust visualization
    perplexity_values = [
        min(n_samples // 5, 30),
        min(n_samples // 3, 50),
        min(n_samples // 2, 100),
    ]
    best_kl = float("inf")
    best_embedding = None
    for perp in perplexity_values:
        print(f"Trying perplexity {perp}...")
        tsne.set_params(perplexity=perp)
        current_embedding = tsne.fit_transform(embeddings_scaled)
        if tsne.kl_divergence_ < best_kl:
            best_kl = tsne.kl_divergence_
            best_embedding = current_embedding
    embeddings_2d = best_embedding
print("Creating plot...") | |
# Create plot | |
fig, ax = plt.subplots(figsize=(10, 8)) | |
fig.patch.set_facecolor("white") | |
ax.set_facecolor("white") | |
ax.grid(True, linestyle="-", alpha=0.2, color="gray", linewidth=0.5) | |
# Plot each class | |
for class_idx, class_name in enumerate(class_names): | |
mask = labels == class_idx | |
ax.scatter( | |
embeddings_2d[mask, 0], | |
embeddings_2d[mask, 1], | |
c=[colors[class_idx]], | |
label=class_name, | |
alpha=0.6, | |
s=40, | |
edgecolor="none", | |
) | |
# Add centroid marker for each class | |
centroid = np.mean(embeddings_2d[mask], axis=0) | |
ax.scatter( | |
centroid[0], | |
centroid[1], | |
c="black", | |
s=150, | |
marker="*", | |
edgecolor="white", | |
linewidth=1, | |
zorder=5, | |
) | |
# Add confidence ellipse if enough points exist | |
if np.sum(mask) > 2: | |
cov = np.cov(embeddings_2d[mask, 0], embeddings_2d[mask, 1]) | |
eigenvals, eigenvecs = np.linalg.eigh(cov) | |
angle = np.degrees(np.arctan2(eigenvecs[1, 0], eigenvecs[0, 0])) | |
ellip = Ellipse( | |
xy=centroid, | |
width=2 * 2 * np.sqrt(eigenvals[0]), | |
height=2 * 2 * np.sqrt(eigenvals[1]), | |
angle=angle, | |
facecolor="none", | |
edgecolor=colors[class_idx], | |
alpha=0.3, | |
linestyle="--", | |
linewidth=0.75, | |
) | |
ax.add_patch(ellip) | |
# Formatting the plot | |
ax.spines["top"].set_visible(False) | |
ax.spines["right"].set_visible(False) | |
for spine in ax.spines.values(): | |
spine.set_linewidth(0.75) | |
ax.set_title(title, fontsize=11, pad=10, weight="bold") | |
ax.set_xlabel("t-SNE Dimension 1", fontsize=10, labelpad=8) | |
ax.set_ylabel("t-SNE Dimension 2", fontsize=10, labelpad=8) | |
ax.tick_params(axis="both", which="major", labelsize=9, width=0.75, length=4) | |
# Add legend and sample size annotation | |
legend = ax.legend( | |
title="Brain Tumor Subtypes", | |
title_fontsize=10, | |
fontsize=9, | |
bbox_to_anchor=(1.05, 1), | |
loc="upper left", | |
borderaxespad=0, | |
frameon=True, | |
edgecolor="black", | |
fancybox=False, | |
) | |
legend.get_frame().set_linewidth(0.75) | |
ax.text( | |
0.02, | |
0.98, | |
f"n = {len(embeddings)}", | |
transform=ax.transAxes, | |
fontsize=9, | |
verticalalignment="top", | |
) | |
plt.tight_layout() | |
if output_path: | |
plt.savefig(output_path, dpi=600, bbox_inches="tight") | |
plt.close() | |
print(f"Saved visualization to {output_path}") | |
else: | |
plt.show() | |
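
# Optional, hypothetical helper (not part of the original gist): a quantitative
# counterpart to the t-SNE plot. scikit-learn's silhouette_score summarizes how well
# the classes separate in the original high-dimensional embedding space, so the
# conclusion does not rest on the 2-D projection alone.
def embedding_separation_score(embeddings, labels):
    from sklearn.metrics import silhouette_score  # local import to keep the sketch self-contained

    # Cosine distance matches how the embeddings are compared elsewhere in this script
    score = silhouette_score(embeddings, labels, metric="cosine")
    print(f"Silhouette score (cosine): {score:.3f}")
    return score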
# Main execution example
if __name__ == "__main__":
    # Load the model
    model, device = setup_model(use_multi_gpu=True)

    # Load the data
    notes_df = pd.read_csv("data/brain_tumor_subtypes_cleaned.csv")

    # Define class names for the visualization
    class_names = [
        "glioblastoma",       # 0
        "astrocytoma",        # 1
        "oligodendroglioma",  # 2
        "oligoastrocytoma",   # 3
    ]

    # Either load pre-computed embeddings or generate them
    try:
        # Try to load existing embeddings
        embeddings = np.load("nv_embed_subtype.npy")
        print(f"Loaded embeddings with shape: {embeddings.shape}")
    except FileNotFoundError:
        # Generate embeddings if the file doesn't exist
        print("Embeddings file not found. Generating new embeddings...")
        embeddings = generate_embeddings(
            df=notes_df,
            text_column="clinical_note",  # Replace with your text column
            model=model,
            device=device,
            batch_size=16,
            instruction="Represent this clinical note for tumor subtype classification:",
            max_length=32768,
            output_file="nv_embed_subtype.npy",
        )

    # Test similarity on a few examples
    test_embedding_similarity(
        embeddings=embeddings,
        texts=notes_df["clinical_note"].tolist(),  # Replace with your text column
        indices=[0, 10, 20],  # Test a few specific examples
        top_k=3,
    )

    # Create the visualization
    labels = notes_df["class_label"].values
    custom_colors = ["#E64B35", "#4DBBD5", "#00A087", "#3C5488"]  # Custom color palette
    visualize_embeddings(
        embeddings=embeddings,
        labels=labels,
        class_names=class_names,
        title="Brain Tumor Subtypes - NV-Embed-v2",
        output_path="brain_tumor_embeddings.png",
        colors=custom_colors,
    )