Last active
April 4, 2025 09:24
-
-
Save MichelNivard/21734f228ec29d4c91fcb123f2ec4aaf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Full training script for protein contact map diffusion model
# Using LucidRain's denoising-diffusion-pytorch (grayscale input)
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import matplotlib.pyplot as plt
import numpy as np

# Denoising U-Net backbone. channels=1 because the contact maps are
# single-channel (grayscale) images.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    flash_attn=False,
    channels=1,
)

# Diffusion wrapper around the U-Net: 1000 noising steps for training, while
# sampling uses fewer (DDIM) steps for speed.
diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=1000,          # number of steps
    sampling_timesteps=150,  # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
)

trainer = Trainer(
    diffusion,
    'chapter_5_diffusion/contact_maps_128/',
    train_batch_size=12,
    train_lr=8e-5,
    train_num_steps=600,           # total training steps
    gradient_accumulate_every=16,  # gradient accumulation steps
    ema_decay=0.995,               # exponential moving average decay
    calculate_fid=False,           # whether to calculate fid during training
    # Contact maps are symmetric about the main diagonal (the visualization
    # below relies on this when it symmetrizes frames). The library's default
    # random horizontal flip mirrors columns, which turns diagonal symmetry
    # into anti-diagonal symmetry and trains on mis-oriented maps, so disable it.
    augment_horizontal_flip=False,
)
# NOTE(review): assuming the Trainer's default save_and_sample_every is 1000,
# 600 training steps never reach a save/sample milestone, so no checkpoint is
# written to disk. Pass save_and_sample_every (or call trainer.save(...)) if
# the trained weights need to persist — verify against the installed version.
trainer.train()
#### Visualize model results
# Draw one sample, keeping every intermediate denoising state, and plot
# NUM_PANELS evenly spaced snapshots of that trajectory side by side.
NUM_PANELS = 5  # number of snapshots to plot

# Sample with the EMA weights when the trainer holds an EMA copy (ema_decay is
# configured above, and denoising-diffusion-pytorch's Trainer keeps it at
# trainer.ema.ema_model); otherwise fall back to the raw trained model.
sampler = trainer.ema.ema_model if hasattr(trainer, "ema") else diffusion
samples = sampler.sample(batch_size=1, return_all_timesteps=True)
print("Samples shape:", samples.shape)
# First (only) batch item. Assumed layout: (n_frames, channels, H, W) with
# channels == 1 and H == W == image_size (128), not 256 — TODO confirm.
samples = samples[0]
timesteps = samples.shape[0]  # number of stored frames in the trajectory
step_indices = np.linspace(0, timesteps - 1, NUM_PANELS, dtype=int)

# squeeze=False keeps a 2-D axes array even when NUM_PANELS == 1, so indexing
# row 0 always yields an iterable of Axes.
fig, axs = plt.subplots(1, NUM_PANELS, figsize=(3 * NUM_PANELS, 3), squeeze=False)
axs = axs[0]
for ax, step_idx in zip(axs, step_indices):
    frame = samples[step_idx, 0].detach().cpu().numpy()  # (H, W) single-channel frame
    # Force symmetry along the diagonal
    frame = (frame + frame.T) / 2
    # Negate so high values render dark under viridis (same as viridis_r).
    ax.imshow(-frame, cmap='viridis')
    ax.set_title(f"Step {step_idx}")
    ax.axis('off')
plt.tight_layout()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment