Skip to content

Instantly share code, notes, and snippets.

@benthecoder
Created February 28, 2025 23:58
Show Gist options
  • Save benthecoder/f253f713766122d97ca0399d7e882f09 to your computer and use it in GitHub Desktop.
Save benthecoder/f253f713766122d97ca0399d7e882f09 to your computer and use it in GitHub Desktop.
# NV-Embed-v2 Medical Text Embedding and Visualization
# This script demonstrates how to use NVIDIA's NV-Embed-v2 model for medical text embeddings
# and visualize the results using t-SNE
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from tqdm import tqdm
from transformers import AutoModel
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Ellipse
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
# Initialize model with multi-GPU support if available, otherwise use single GPU
def setup_model(model_name="nvidia/NV-Embed-v2", use_multi_gpu=True, cache_dir="~/cache"):
    """
    Load the embedding model and place it on the available device(s).

    Args:
        model_name: HuggingFace model identifier passed to
            ``AutoModel.from_pretrained``.
        use_multi_gpu: Whether to spread the model's sub-modules over all
            visible GPUs when more than one is available.
        cache_dir: Directory to cache model files. A leading ``~`` is
            expanded to the user's home directory before use.

    Returns:
        Tuple ``(model, device)`` where ``device`` is ``"cuda"`` or ``"cpu"``.

    Raises:
        Exception: Any error raised while loading or moving the model is
            printed and then re-raised unchanged.
    """
    # Local import so this function stays self-contained.
    import os

    # Expand "~" explicitly so the default "~/cache" points at the home
    # directory instead of a literal "~" folder under the working directory.
    if cache_dir is not None:
        cache_dir = os.path.expanduser(cache_dir)

    try:
        # Check GPU availability
        if torch.cuda.is_available():
            device = "cuda"
        else:
            print("WARNING: No GPU detected. Running on CPU will be very slow.")
            device = "cpu"
            use_multi_gpu = False

        # Load model (trust_remote_code lets the checkpoint ship its own model class)
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )

        # Configure for multi-GPU if requested and available
        if use_multi_gpu and torch.cuda.device_count() > 1:
            device_count = torch.cuda.device_count()
            device_ids = list(range(device_count))
            print(f"Using {device_count} GPUs: {device_ids}")
            # Wrap every direct sub-module so each one is replicated across GPUs
            for module_key, module in model._modules.items():
                model._modules[module_key] = DataParallel(module, device_ids=device_ids)
            model.to(f"cuda:{device_ids[0]}")  # Move to primary GPU
        else:
            # Single GPU or CPU
            print(f"Using single device: {device}")
            model.to(device)

        return model, device
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
def generate_embeddings(
    df,
    text_column,
    model,
    device,
    batch_size=16,
    instruction="",
    max_length=32768,
    output_file=None
):
    """
    Generate L2-normalized embeddings for the texts in a dataframe.

    Args:
        df: Pandas DataFrame containing text data.
        text_column: Column name containing the text to embed.
        model: The embedding model; must expose
            ``encode(texts, instruction=..., max_length=..., show_progress_bar=...)``.
            The result is normalized along ``dim=1``, so it is assumed to be a
            2-D tensor (one row per text) -- confirm against the model used.
        device: Device the model runs on (cuda/cpu). Kept for interface
            compatibility; results are always copied to host memory.
        batch_size: Number of texts to process at once.
        instruction: Optional instruction prefix, e.g.
            "Represent this clinical note for finding similar patient cases:"
            or "Encode this medical text for disease classification:".
        max_length: Maximum sequence length to process.
        output_file: Optional path to save the embeddings with ``np.save``.

    Returns:
        NumPy array of embeddings, one row per input row. An empty dataframe
        yields an empty array of shape ``(0, 0)`` and nothing is saved.

    Raises:
        RuntimeError: Re-raised from the model after printing a hint
            (e.g. on CUDA out-of-memory).
    """
    # Nothing to embed: torch.cat would fail on an empty list, so bail out early.
    if len(df) == 0:
        print("No texts to embed; returning an empty array.")
        return np.empty((0, 0), dtype=np.float32)

    embeddings = []

    # Process in batches
    for i in tqdm(range(0, len(df), batch_size)):
        batch_texts = df[text_column].iloc[i:i + batch_size].tolist()
        try:
            with torch.no_grad():  # Disable gradient tracking for inference
                batch_embeddings = model.encode(
                    batch_texts,
                    instruction=instruction,
                    max_length=max_length,
                    show_progress_bar=False,
                )
            # Normalize embeddings (L2 norm) so dot products equal cosine similarity
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e) and batch_size > 1:
                print("CUDA OOM error. Consider reducing batch_size or max_length.")
            else:
                print(f"Error processing batch {i}-{i+batch_size}: {e}")
            raise

        # Move each batch to host memory right away so finished batches do not
        # pile up on the GPU while later batches are still being encoded.
        embeddings.append(batch_embeddings.cpu())

    # Concatenate all batches (already on CPU regardless of the device string)
    all_embeddings = torch.cat(embeddings, dim=0).numpy()

    # Save embeddings if output file specified
    if output_file:
        np.save(output_file, all_embeddings)
        print(f"Saved embeddings to {output_file}")

    print(f"Shape of embeddings: {all_embeddings.shape}")
    return all_embeddings
def test_embedding_similarity(embeddings, texts, indices=None, top_k=5):
    """
    Print the nearest neighbours of selected embeddings as a quality check.

    Similarity is the plain dot product between rows, which equals cosine
    similarity only when the rows are L2-normalized.

    Args:
        embeddings: NumPy array of embeddings, one row per text.
        texts: List of original texts corresponding to embeddings.
        indices: Optional list of row indices to use as queries
            (default: one randomly selected row).
        top_k: Number of most similar results to print per query.

    Returns:
        None (prints results)
    """
    if indices is None:
        # Select a random index if none provided
        indices = [np.random.randint(0, len(embeddings))]

    for query_idx in indices:
        query_embedding = embeddings[query_idx:query_idx + 1]

        # Dot product of every row against the query row
        similarities = np.dot(embeddings, query_embedding.T).flatten()

        # Rank descending and drop the query itself by index. Dropping "rank 0"
        # is not enough: with unnormalized or duplicate rows the query is not
        # guaranteed to be ranked first.
        ranked = np.argsort(similarities)[::-1]
        most_similar = ranked[ranked != query_idx][:top_k]

        # str() guards against non-string entries such as NaN from pandas
        print(f"\nQuery text [{query_idx}]:\n{str(texts[query_idx])[:300]}...\n")
        print("Most similar texts:")
        for rank, idx in enumerate(most_similar, start=1):
            print(f"{rank}. Similarity: {similarities[idx]:.4f}")
            print(f" Text [{idx}]: {str(texts[idx])[:200]}...\n")


# The "test_" prefix would make pytest collect this helper as a test case
# (and fail on its arguments); opt out of collection explicitly.
test_embedding_similarity.__test__ = False
def _candidate_perplexities(n_samples):
    """Return the distinct t-SNE perplexities to try, each clamped to [1, n_samples - 1]."""
    upper = max(n_samples - 1, 1)
    raw = (
        min(n_samples // 5, 30),
        min(n_samples // 3, 50),
        min(n_samples // 2, 100),
    )
    # dict.fromkeys de-duplicates while keeping order, so small datasets do not
    # re-run an identical (seeded) t-SNE fit.
    return list(dict.fromkeys(max(1, min(p, upper)) for p in raw))


def visualize_embeddings(
    embeddings,
    labels,
    class_names,
    title,
    output_path=None,
    colors=None,
    legend_title="Brain Tumor Subtypes",
):
    """Create t-SNE visualization of embeddings

    Embeddings are standardized, projected to 2-D with t-SNE at several
    perplexities (the run with the lowest KL divergence is kept), and drawn as
    one scatter per class with a centroid marker and a dashed 2-sigma ellipse.

    Args:
        embeddings: Array of embeddings to visualize, one row per sample.
        labels: Array-like of integer class labels, one per sample; label ``k``
            is drawn with ``class_names[k]``.
        class_names: List of class names corresponding to labels.
        title: Title for the plot.
        output_path: Optional path to save visualization; when omitted the
            figure is shown interactively instead.
        colors: Optional list of colors for classes; reused cyclically when
            there are more classes than colors.
        legend_title: Title shown above the legend entries.

    Raises:
        ValueError: If fewer than two samples are given or ``labels`` does not
            have one entry per embedding.
    """
    print(f"Creating t-SNE visualization for {title}...")

    # Set default colors if not provided
    if colors is None:
        colors = ["#E64B35", "#4DBBD5", "#00A087", "#3C5488"]

    # Accept plain lists too: boolean masks below need array semantics
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    n_samples = len(embeddings)
    if n_samples < 2:
        raise ValueError("t-SNE visualization needs at least 2 embeddings")
    if len(labels) != n_samples:
        raise ValueError(
            f"labels has {len(labels)} entries but embeddings has {n_samples} rows"
        )

    # Try multiple perplexity values for more robust visualization.
    # Values are clamped because t-SNE rejects perplexity <= 0 or >= n_samples.
    perplexity_values = _candidate_perplexities(n_samples)

    tsne = TSNE(
        n_components=2,
        perplexity=perplexity_values[0],
        early_exaggeration=12,
        learning_rate="auto",
        init="pca",
        random_state=42,
        max_iter=2500,
        min_grad_norm=1e-7,
        metric="cosine",
        n_iter_without_progress=300,
    )
    embeddings_scaled = StandardScaler().fit_transform(embeddings)

    best_kl = float("inf")
    best_embedding = None
    for perp in perplexity_values:
        print(f"Trying perplexity {perp}...")
        tsne.set_params(perplexity=perp)
        current_embedding = tsne.fit_transform(embeddings_scaled)
        if tsne.kl_divergence_ < best_kl:
            best_kl = tsne.kl_divergence_
            best_embedding = current_embedding
    embeddings_2d = best_embedding

    print("Creating plot...")
    # Create plot
    fig, ax = plt.subplots(figsize=(10, 8))
    fig.patch.set_facecolor("white")
    ax.set_facecolor("white")
    ax.grid(True, linestyle="-", alpha=0.2, color="gray", linewidth=0.5)

    # Plot each class
    for class_idx, class_name in enumerate(class_names):
        mask = labels == class_idx
        color = colors[class_idx % len(colors)]  # cycle if classes outnumber colors
        ax.scatter(
            embeddings_2d[mask, 0],
            embeddings_2d[mask, 1],
            c=[color],
            label=class_name,
            alpha=0.6,
            s=40,
            edgecolor="none",
        )

        # A class without samples keeps its legend entry but has no centroid
        # or ellipse (the mean of zero points would be NaN).
        if not mask.any():
            continue

        # Add centroid marker for each class
        centroid = np.mean(embeddings_2d[mask], axis=0)
        ax.scatter(
            centroid[0],
            centroid[1],
            c="black",
            s=150,
            marker="*",
            edgecolor="white",
            linewidth=1,
            zorder=5,
        )

        # Add confidence ellipse if enough points exist
        if np.sum(mask) > 2:
            cov = np.cov(embeddings_2d[mask, 0], embeddings_2d[mask, 1])
            eigenvals, eigenvecs = np.linalg.eigh(cov)
            # Round-off can make a near-zero eigenvalue slightly negative
            axis_std = np.sqrt(np.maximum(eigenvals, 0.0))
            # Orientation of the eigenvector paired with eigenvals[0] (the width axis)
            angle = np.degrees(np.arctan2(eigenvecs[1, 0], eigenvecs[0, 0]))
            ellip = Ellipse(
                xy=centroid,
                width=2 * 2 * axis_std[0],  # full axis length = 2 x (2 std devs)
                height=2 * 2 * axis_std[1],
                angle=angle,
                facecolor="none",
                edgecolor=color,
                alpha=0.3,
                linestyle="--",
                linewidth=0.75,
            )
            ax.add_patch(ellip)

    # Formatting the plot
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    for spine in ax.spines.values():
        spine.set_linewidth(0.75)
    ax.set_title(title, fontsize=11, pad=10, weight="bold")
    ax.set_xlabel("t-SNE Dimension 1", fontsize=10, labelpad=8)
    ax.set_ylabel("t-SNE Dimension 2", fontsize=10, labelpad=8)
    ax.tick_params(axis="both", which="major", labelsize=9, width=0.75, length=4)

    # Add legend and sample size annotation
    legend = ax.legend(
        title=legend_title,
        title_fontsize=10,
        fontsize=9,
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        borderaxespad=0,
        frameon=True,
        edgecolor="black",
        fancybox=False,
    )
    legend.get_frame().set_linewidth(0.75)
    ax.text(
        0.02,
        0.98,
        f"n = {len(embeddings)}",
        transform=ax.transAxes,
        fontsize=9,
        verticalalignment="top",
    )
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=600, bbox_inches="tight")
        plt.close()
        print(f"Saved visualization to {output_path}")
    else:
        plt.show()
# Main execution example
if __name__ == "__main__":
    text_column = "clinical_note"  # Replace with your text column
    embeddings_file = "nv_embed_subtype.npy"

    # Load data
    notes_df = pd.read_csv("data/brain_tumor_subtypes_cleaned.csv")

    # Define class names for visualization
    class_names = [
        "glioblastoma",  # 0
        "astrocytoma",  # 1
        "oligodendroglioma",  # 2
        "oligoastrocytoma",  # 3
    ]

    # Either load pre-computed embeddings or generate them
    try:
        # Try to load existing embeddings
        embeddings = np.load(embeddings_file)
        print(f"Loaded embeddings with shape: {embeddings.shape}")
    except FileNotFoundError:
        # Generate embeddings if file doesn't exist
        print("Embeddings file not found. Generating new embeddings...")
        # Load the model only here: it is not needed when cached embeddings
        # exist, and loading it is by far the most expensive step.
        model, device = setup_model(use_multi_gpu=True)
        embeddings = generate_embeddings(
            df=notes_df,
            text_column=text_column,
            model=model,
            device=device,
            batch_size=16,
            instruction="Represent this clinical note for tumor subtype classification:",
            max_length=32768,
            output_file=embeddings_file,
        )

    # Cached embeddings may come from a different version of the CSV; fail
    # early instead of pairing rows with the wrong texts and labels.
    if len(embeddings) != len(notes_df):
        raise ValueError(
            f"{embeddings_file} has {len(embeddings)} rows but the dataset has "
            f"{len(notes_df)}; delete the file to regenerate embeddings."
        )

    # Test similarity on a few examples
    test_embedding_similarity(
        embeddings=embeddings,
        texts=notes_df[text_column].tolist(),
        indices=[0, 10, 20],  # Test a few specific examples
        top_k=3,
    )

    # Create visualization
    labels = notes_df["class_label"].values
    custom_colors = ["#E64B35", "#4DBBD5", "#00A087", "#3C5488"]  # Color-blind friendly palette
    visualize_embeddings(
        embeddings=embeddings,
        labels=labels,
        class_names=class_names,
        title="Brain Tumor Subtypes - NV-Embed-v2",
        output_path="brain_tumor_embeddings.png",
        colors=custom_colors,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment