Skip to content

Instantly share code, notes, and snippets.

@g023
Created December 8, 2025 13:07
Show Gist options
  • Select an option

  • Save g023/461ff5a63761e99f055710ebfdefa3b5 to your computer and use it in GitHub Desktop.

Select an option

Save g023/461ff5a63761e99f055710ebfdefa3b5 to your computer and use it in GitHub Desktop.
Simulated diffusion inferencing example combining DistilGPT-2 and DistilBERT
# Simulated diffusion inferencing example combining DistilGPT-2 and DistilBERT
# This simulates a diffusion-like process: generate a draft with GPT-2, then iteratively refine it by masking random tokens and re-filling them with BERT.
# Author: g023 - https://github.com/g023/ -
import torch
import random
from transformers import GPT2Tokenizer, GPT2LMHeadModel, DistilBertTokenizer, DistilBertForMaskedLM
from transformers import logging
# Silence transformers' log output (errors only) so just the demo's prints appear.
logging.set_verbosity_error()

# ---- Tunable parameters ------------------------------------------------------
# Checkpoint identifiers handed to from_pretrained().
GPT2_MODEL_NAME: str = 'distilgpt2'
BERT_MODEL_NAME: str = 'distilbert-base-uncased'
# Sampling settings for the initial DistilGPT-2 draft.
MAX_LENGTH: int = 30  # forwarded to generate() as max_length
TEMPERATURE: float = 0.7
TOP_K: int = 50
TOP_P: float = 0.9
NO_REPEAT_NGRAM_SIZE: int = 2
# Refinement settings.
MASK_PROB: float = 0.15  # per-token masking probability used by mask_text
DIFFUSION_STEPS: int = 2  # number of mask -> denoise rounds
EXAMPLE_PROMPT: str = "The cat sat on the"

# ---- Models (loaded once, at import time) ------------------------------------
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_NAME)
gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL_NAME)
# GPT-2's tokenizer has no dedicated pad token; reuse EOS for that role.
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
bert_tokenizer = DistilBertTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = DistilBertForMaskedLM.from_pretrained(BERT_MODEL_NAME)
def generate_with_gpt2(prompt, max_length=MAX_LENGTH, *, temperature=None, top_k=None, top_p=None):
    """Generate an initial sampled completion of *prompt* with DistilGPT-2.

    Args:
        prompt: Text to continue.
        max_length: ``max_length`` passed to ``generate`` -- the total token
            budget, counting the prompt tokens as well as the new ones.
        temperature: Optional per-call override; ``None`` uses TEMPERATURE.
        top_k: Optional per-call override; ``None`` uses TOP_K.
        top_p: Optional per-call override; ``None`` uses TOP_P.

    Returns:
        The decoded prompt + continuation, with special tokens skipped and
        every non-ASCII character removed.
    """
    # Resolve overrides at call time so the module-level constants are read
    # when the function runs (same as before), not frozen at definition time.
    temperature = TEMPERATURE if temperature is None else temperature
    top_k = TOP_K if top_k is None else top_k
    top_p = TOP_P if top_p is None else top_p

    inputs = gpt2_tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        outputs = gpt2_model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
            pad_token_id=gpt2_tokenizer.eos_token_id,
        )
    text = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop (not replace) every non-ASCII character. Note this also deletes
    # accented letters and typographic quotes outright.
    return text.encode('ascii', 'ignore').decode('ascii')
def mask_text(text, mask_prob=MASK_PROB):
    """Randomly replace whole single-piece words in *text* with BERT's mask token.

    *text* is split with the DistilBERT WordPiece tokenizer. Each purely
    alphabetic token that is not the head of a multi-piece word is replaced by
    the tokenizer's mask token (``[MASK]`` for DistilBERT) with probability
    *mask_prob*. The tokens are then joined back into a string.

    Args:
        text: Text to corrupt.
        mask_prob: Probability of masking each eligible token.

    Returns:
        The re-joined (tokenizer-normalised) text containing zero or more masks.
    """
    tokens = bert_tokenizer.tokenize(text)
    mask_token = bert_tokenizer.mask_token
    masked_tokens = []
    for i, token in enumerate(tokens):
        # Draw once per token, before any eligibility test, so the amount of
        # RNG consumed depends only on the number of tokens.
        hit = random.random() < mask_prob
        # Masking the head of a split word (e.g. "play" in "play ##ing") would
        # leave an orphaned "##ing". convert_tokens_to_string glues that onto
        # the mask as "[MASK]ing", and BERT then re-reads "ing" as a separate
        # word, garbling the output.
        starts_split_word = i + 1 < len(tokens) and tokens[i + 1].startswith('##')
        # isalpha() already rejects "##" continuation pieces and bracketed
        # special tokens such as [CLS], [SEP] and [MASK].
        if hit and token.isalpha() and not starts_split_word:
            masked_tokens.append(mask_token)
        else:
            masked_tokens.append(token)
    return bert_tokenizer.convert_tokens_to_string(masked_tokens)
def denoise_with_bert(masked_text):
    """Fill every mask token in *masked_text* using DistilBERT.

    All masks are predicted from a single forward pass, each position
    independently. For each mask, the top-5 predictions are scanned in order
    and the first one that is not a special token ([UNK], [CLS], [SEP], [PAD],
    [MASK], ...) is used. If all five are special, the top-1 is used.

    Args:
        masked_text: Text containing zero or more mask tokens.

    Returns:
        *masked_text* unchanged if it has no mask token. Otherwise the filled
        text, rebuilt from BERT tokens without the [CLS]/[SEP] wrapper.
    """
    inputs = bert_tokenizer(masked_text, return_tensors='pt')
    input_ids = inputs['input_ids'][0]  # single sequence: drop the batch dim
    mask_positions = (input_ids == bert_tokenizer.mask_token_id).nonzero(as_tuple=True)[0].tolist()
    if not mask_positions:
        return masked_text

    with torch.no_grad():
        logits = bert_model(**inputs).logits[0]

    # Filter every special id, not just [UNK]. A predicted [SEP]/[CLS] spliced
    # mid-text would survive the [1:-1] strip below and corrupt the next
    # round, and a predicted [MASK] would leave the position unfilled.
    special_ids = set(bert_tokenizer.all_special_ids)
    # Stay in token space so convert_tokens_to_string rejoins "##" pieces.
    filled_tokens = bert_tokenizer.convert_ids_to_tokens(input_ids.tolist())
    for pos in mask_positions:
        top_ids = torch.topk(logits[pos], 5).indices.tolist()
        best_id = next((tid for tid in top_ids if tid not in special_ids), top_ids[0])
        filled_tokens[pos] = bert_tokenizer.convert_ids_to_tokens(best_id)

    # Strip the [CLS] ... [SEP] wrapper the tokenizer added.
    filled_tokens = filled_tokens[1:-1]
    return bert_tokenizer.convert_tokens_to_string(filled_tokens)
def simulated_diffusion(prompt, steps=DIFFUSION_STEPS, mask_prob=MASK_PROB, gen_length=MAX_LENGTH):
    """Draft text with DistilGPT-2, then run *steps* mask/denoise rounds with DistilBERT.

    Prints the draft and the text after every round.

    Args:
        prompt: Seed text for the GPT-2 draft.
        steps: Number of refinement rounds.
        mask_prob: Per-token masking probability forwarded to mask_text.
        gen_length: max_length forwarded to generate_with_gpt2.

    Returns:
        The text after the final round (the draft itself if *steps* <= 0).
    """
    # Draft: one sampled continuation from DistilGPT-2.
    text = generate_with_gpt2(prompt, max_length=gen_length)
    print(f"Initial generation: {text}")

    # Refinement: corrupt the current text with random masks, then let BERT repair it.
    for step_no in range(1, steps + 1):
        corrupted = mask_text(text, mask_prob=mask_prob)
        text = denoise_with_bert(corrupted)
        print(f"After step {step_no}: {text}")
    return text
# Example usage: run the whole pipeline once on the built-in prompt.
if __name__ == "__main__":
    prompt = EXAMPLE_PROMPT
    final_text = simulated_diffusion(
        prompt,
        steps=DIFFUSION_STEPS,
    )
    print(f"\nFinal refined text: {final_text}")
@g023
Copy link
Copy Markdown
Author

g023 commented Dec 8, 2025

distildiffuse

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment