Skip to content

Instantly share code, notes, and snippets.

@MichelNivard
Last active April 4, 2025 09:24
Show Gist options
  • Save MichelNivard/21734f228ec29d4c91fcb123f2ec4aaf to your computer and use it in GitHub Desktop.
Save MichelNivard/21734f228ec29d4c91fcb123f2ec4aaf to your computer and use it in GitHub Desktop.
# Full training script for protein contact map diffusion model
# Using lucidrains' denoising-diffusion-pytorch (grayscale input)
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import matplotlib.pyplot as plt
import numpy as np
# U-Net that is handed to GaussianDiffusion; channels=1 because the contact
# maps are fed in as single-channel (grayscale) images.
model = Unet(
    dim=64,
    channels=1,
    dim_mults=(1, 2, 4, 8),
    flash_attn=False,
)
# Wrap the U-Net in the Gaussian diffusion process for 128-pixel images.
diffusion = GaussianDiffusion(
    model,
    image_size=128,
    # Sampling uses fewer steps than training: DDIM sampling for faster
    # inference (see the DDIM paper cited by the library).
    sampling_timesteps=150,
    timesteps=1000,  # number of diffusion steps
)
# Configure the training run, then launch it.
trainer = Trainer(
    diffusion,
    # Training data location -- presumably a folder of contact-map images;
    # TODO confirm against the Trainer signature.
    'chapter_5_diffusion/contact_maps_128/',
    train_batch_size=12,
    gradient_accumulate_every=16,  # gradient accumulation steps
    train_lr=8e-5,
    train_num_steps=600,  # total training steps
    ema_decay=0.995,  # exponential moving average decay
    calculate_fid=False,  # do not calculate FID during training
)
trainer.train()
#### Visualize model results
# Draw a single sample and keep every intermediate step of the sampling loop
# so the trajectory from noise to final contact map can be plotted.
# NOTE(review): this samples from `diffusion` directly; if the Trainer keeps a
# separate EMA copy of the weights (see ema_decay above), that copy is not
# used here -- confirm this is intended.
samples = diffusion.sample(batch_size=1, return_all_timesteps=True)
print("Samples shape:", samples.shape)

# Drop the batch axis. Presumably (steps, 1, 128, 128) given image_size=128
# above -- check against the shape printed on the previous line.
samples = samples[0]
timesteps = samples.shape[0]

# Pick evenly spaced snapshots from the first to the last recorded step.
n_panels = 5
step_indices = np.linspace(0, timesteps - 1, n_panels, dtype=int)

fig, axs = plt.subplots(1, n_panels, figsize=(15, 3))
for ax, step_idx in zip(axs, step_indices):
    # Channel 0 of the chosen step as a 2-D NumPy array.
    frame = samples[step_idx, 0].detach().cpu().numpy()
    # Force symmetry along the diagonal by averaging with the transpose.
    frame = (frame + frame.T) / 2
    # Negate so the colour scale is inverted relative to the raw values.
    ax.imshow(-frame, cmap='viridis')
    ax.set_title(f"Step {step_idx}")
    ax.axis('off')
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment