Created
December 8, 2025 13:07
-
-
Save g023/461ff5a63761e99f055710ebfdefa3b5 to your computer and use it in GitHub Desktop.
Simulated diffusion inferencing example combining DistilGPT-2 and DistilBERT
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Simulated diffusion inferencing example combining DistilGPT-2 and DistilBERT | |
| # This simulates a diffusion-like process: generate with GPT-2, then iteratively refine by masking and filling with BERT. | |
| # Author: g023 - https://github.com/g023/ - | |
| import torch | |
| import random | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel, DistilBertTokenizer, DistilBertForMaskedLM | |
| from transformers import logging | |
# Restrict the transformers logger to errors only (hides warnings).
logging.set_verbosity_error()

# ---------------------------------------------------------------------------
# Global parameters
# ---------------------------------------------------------------------------
# Checkpoint names handed to ``from_pretrained`` below.
GPT2_MODEL_NAME = 'distilgpt2'
BERT_MODEL_NAME = 'distilbert-base-uncased'

# Settings used by generate_with_gpt2 when sampling the initial text.
MAX_LENGTH = 30
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.9
NO_REPEAT_NGRAM_SIZE = 2

# Settings used by the mask-and-fill refinement loop.
MASK_PROB = 0.15
DIFFUSION_STEPS = 2

# Prompt used when the file is run as a script.
EXAMPLE_PROMPT = "The cat sat on the"

# ---------------------------------------------------------------------------
# Load models and tokenizers
# ---------------------------------------------------------------------------
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_NAME)
gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL_NAME)
# Use the end-of-sequence token as the padding token for the GPT-2 tokenizer.
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

bert_tokenizer = DistilBertTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = DistilBertForMaskedLM.from_pretrained(BERT_MODEL_NAME)
def generate_with_gpt2(prompt, max_length=MAX_LENGTH):
    """Sample an initial completion of ``prompt`` with DistilGPT-2.

    Args:
        prompt: Text the generation is conditioned on.
        max_length: Value forwarded to ``generate`` as ``max_length``.

    Returns:
        The first generated sequence, decoded with special tokens skipped
        and with every non-ASCII character removed.
    """
    encoded_prompt = gpt2_tokenizer(prompt, return_tensors='pt')
    sampling_options = dict(
        max_length=max_length,
        num_return_sequences=1,
        temperature=TEMPERATURE,
        do_sample=True,
        top_k=TOP_K,
        top_p=TOP_P,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        pad_token_id=gpt2_tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = gpt2_model.generate(**encoded_prompt, **sampling_options)

    decoded = gpt2_tokenizer.decode(sequences[0], skip_special_tokens=True)
    # Non-ASCII characters are dropped outright, not substituted.
    return decoded.encode('ascii', errors='ignore').decode('ascii')
def mask_text(text, mask_prob=MASK_PROB):
    """Replace a random subset of the text's BERT tokens with ``[MASK]``.

    The text is split with the DistilBERT tokenizer. One random draw is made
    per token; a token is masked only when the draw falls below ``mask_prob``
    and the token is purely alphabetic (``str.isalpha``) and is not one of
    ``[CLS]``, ``[SEP]`` or ``[MASK]``.

    Args:
        text: Text to add masking noise to.
        mask_prob: Per-token threshold compared against ``random.random()``.

    Returns:
        The token sequence, with masks applied, joined back into a string
        by the tokenizer.
    """
    protected = ['[CLS]', '[SEP]', '[MASK]']

    def _maybe_mask(piece):
        # The random draw always happens first, once per token.
        if not (random.random() < mask_prob):
            return piece
        if not piece.isalpha():
            return piece
        if piece in protected:
            return piece
        return '[MASK]'

    noisy_pieces = [_maybe_mask(piece) for piece in bert_tokenizer.tokenize(text)]
    return bert_tokenizer.convert_tokens_to_string(noisy_pieces)
def denoise_with_bert(masked_text):
    """Fill every ``[MASK]`` in ``masked_text`` using DistilBERT.

    All masked positions are predicted from a single forward pass. For each
    position the five highest-scoring tokens are decoded and the first one
    that is not ``[UNK]`` is used (falling back to the top candidate when
    all five are ``[UNK]``).

    Args:
        masked_text: Text that may contain ``[MASK]`` placeholders.

    Returns:
        ``masked_text`` unchanged when it contains no mask token; otherwise
        the filled token sequence, without its first and last tokens,
        converted back to a string by the tokenizer.
    """
    encoded = bert_tokenizer(masked_text, return_tensors='pt')
    token_ids = encoded['input_ids']
    is_mask = token_ids == bert_tokenizer.mask_token_id
    # Column indices (sequence positions) of the mask tokens.
    mask_columns = is_mask.nonzero(as_tuple=True)[1]
    if len(mask_columns) == 0:
        return masked_text

    with torch.no_grad():
        logits = bert_model(**encoded).logits

    pieces = bert_tokenizer.convert_ids_to_tokens(token_ids[0])
    for column in mask_columns:
        candidate_ids = torch.topk(logits[0, column], 5).indices
        candidates = [bert_tokenizer.decode([cid.item()]) for cid in candidate_ids]
        # Prefer the best-ranked candidate that is not the unknown token.
        replacement = candidates[0]
        for candidate in candidates:
            if candidate != '[UNK]':
                replacement = candidate
                break
        pieces[column] = replacement

    # Strip the leading [CLS] and trailing [SEP] before rebuilding the text.
    return bert_tokenizer.convert_tokens_to_string(pieces[1:-1])
def simulated_diffusion(prompt, steps=DIFFUSION_STEPS, mask_prob=MASK_PROB, gen_length=MAX_LENGTH):
    """Generate text with GPT-2, then refine it by repeated mask-and-fill.

    Args:
        prompt: Text passed to ``generate_with_gpt2``.
        steps: Number of mask/denoise rounds to run.
        mask_prob: Masking threshold passed to ``mask_text`` each round.
        gen_length: ``max_length`` passed to ``generate_with_gpt2``.

    Returns:
        The text after the final round. The initial generation and the
        result of every round are also printed.
    """
    # Initial draft from the autoregressive model.
    current_text = generate_with_gpt2(prompt, max_length=gen_length)
    print(f"Initial generation: {current_text}")

    # Each round adds masking noise, then lets BERT fill the gaps.
    for step in range(steps):
        current_text = denoise_with_bert(
            mask_text(current_text, mask_prob=mask_prob)
        )
        print(f"After step {step+1}: {current_text}")

    return current_text
# Demo: refine the example prompt and print the final result.
if __name__ == "__main__":
    prompt = EXAMPLE_PROMPT
    final_text = simulated_diffusion(prompt, DIFFUSION_STEPS)
    print(f"\nFinal refined text: {final_text}")
Author
g023
commented
Dec 8, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment