Skip to content

Instantly share code, notes, and snippets.

@lucasnewman
Created June 17, 2026 19:44
Show Gist options
  • Select an option

  • Save lucasnewman/ee6d792db0b5f23f8ad343c7a564944c to your computer and use it in GitHub Desktop.

Select an option

Save lucasnewman/ee6d792db0b5f23f8ad343c7a564944c to your computer and use it in GitHub Desktop.
diff --git a/mlx_vlm/generate/diffusion.py b/mlx_vlm/generate/diffusion.py
index ae56794..31e7ba4 100644
--- a/mlx_vlm/generate/diffusion.py
+++ b/mlx_vlm/generate/diffusion.py
@@ -4,6 +4,7 @@ import logging
import os
import shutil
import time
+from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Tuple
import mlx.core as mx
@@ -31,6 +32,47 @@ DEFAULT_DIFFUSION_MAX_DENOISING_STEPS = 48
DEFAULT_DIFFUSION_UNMASKING_WIDTH = 0
+# Draft event that a canvas denoiser yields while a canvas is still being
+# denoised; stream_diffusion_generate forwards it as an is_draft result.
+@dataclass
+class DiffusionCanvasDraft:
+ # Decoded draft text to display (forwarded as draft_text).
+ text: str
+ # Step this draft belongs to; the standard denoiser uses 0 for the initial draft.
+ step: int
+ # Step count reported next to `step` (forwarded as diffusion_total_steps).
+ total_steps: int
+
+
+# Final event of a canvas denoiser: the finished canvas plus the accounting
+# counters that stream_diffusion_generate adds to its running totals.
+@dataclass
+class DiffusionCanvasResult:
+ # Final token ids; the caller reads row 0 and uses shape[-1] as the committed length.
+ canvas: mx.array
+ # Added to diffusion_canvas_tokens by the caller.
+ canvas_tokens: int
+ # Added to diffusion_denoising_steps by the caller.
+ denoising_steps: int
+ # Added to diffusion_work_tokens; the standard denoiser reports
+ # canvas_length * denoising_steps.
+ work_tokens: int
+
+
+# Bundle of per-canvas inputs that stream_diffusion_generate hands to a canvas
+# denoiser (the standard one or a caller-supplied replacement).
+@dataclass
+class DiffusionCanvasDenoiseContext:
+ model: nn.Module
+ tokenizer: PreTrainedTokenizer
+ # dtype of the prompt input_ids; the standard denoiser casts its canvases to it.
+ input_dtype: Any
+ batch_size: int
+ canvas_length: int
+ vocab_size: int
+ kv_cache: Any
+ # Decoder attention mask for the current canvas, or None.
+ decoder_attention_mask: Optional[mx.array]
+ # Value returned by decoder._make_decoder_masks for this canvas.
+ mask_mapping: Any
+ max_denoising_steps: int
+ temperature: float
+ temperature_config: Dict[str, Any]
+ # Sampler name; the standard denoiser checks for "entropy-bound" and
+ # "confidence-threshold".
+ sampler: str
+ threshold: float
+ entropy_bound: float
+ stopping_config: Optional[Dict[str, Any]]
+ soft_embedding_weight: mx.array
+ # The standard denoiser sets this to False when it falls back to the eager
+ # path; the caller reads it back after the denoiser has finished.
+ compile_graph: bool
+ show_unmasking: bool
+ unmasking_interval: int
+ unmasking_width: int
+ skip_special_token_ids: Any
+
+
def _diffusion_display_limit(requested_width: Optional[int] = None) -> Optional[int]:
terminal_width = shutil.get_terminal_size((120, 20)).columns
requested_width = requested_width or DEFAULT_DIFFUSION_UNMASKING_WIDTH
@@ -527,6 +569,219 @@ def _decode_diffusion_masked_draft(
return escape_carriage_returns(text)
+def _standard_diffusion_canvas_denoiser(
+ ctx: DiffusionCanvasDenoiseContext,
+) -> Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]:
+ """Denoise one canvas with the built-in sampler.
+
+ Yields a DiffusionCanvasDraft for displayed steps when ctx.show_unmasking
+ is set, and always finishes by yielding one DiffusionCanvasResult whose
+ canvas is argmax_canvas (the argmax of the most recently computed logits).
+ May set ctx.compile_graph to False if the compiled decoder path fails.
+ """
+ current_canvas = _diffusion_initialize_canvas(
+ ctx.batch_size,
+ ctx.canvas_length,
+ ctx.vocab_size,
+ ctx.input_dtype,
+ )
+ # Nothing is revealed yet; all working canvases start from the initial canvas.
+ draft_reveal_mask = mx.zeros(current_canvas.shape, dtype=mx.bool_)
+ draft_canvas = current_canvas
+ accepted_canvas = current_canvas
+ argmax_canvas = current_canvas
+ self_conditioning_embeddings = None
+ decoder_logits_without_sc, decoder_logits_with_sc = (
+ _make_diffusion_decoder_logits_fns(
+ ctx.model,
+ ctx.kv_cache,
+ ctx.mask_mapping,
+ compile_graph=ctx.compile_graph,
+ )
+ )
+ diffusion_history: List[mx.array] = []
+ denoising_steps = 0
+ compile_graph = ctx.compile_graph
+
+ # Step 0 draft: the initial canvas with every position marked as unrevealed.
+ if ctx.show_unmasking:
+ draft_text = _decode_diffusion_masked_draft(
+ ctx.tokenizer,
+ [int(token_id) for token_id in draft_canvas[0].tolist()],
+ [False] * ctx.canvas_length,
+ ctx.skip_special_token_ids,
+ max_chars=ctx.unmasking_width,
+ )
+ yield DiffusionCanvasDraft(
+ text=draft_text,
+ step=0,
+ total_steps=ctx.max_denoising_steps,
+ )
+
+ # cur_step counts down from max_denoising_steps to 1.
+ for cur_step in reversed(range(1, ctx.max_denoising_steps + 1)):
+ denoising_steps += 1
+ # Self-conditioning embeddings only exist after a previous step produced them.
+ try:
+ if self_conditioning_embeddings is None:
+ processed_logits = decoder_logits_without_sc(current_canvas)
+ else:
+ processed_logits = decoder_logits_with_sc(
+ current_canvas,
+ self_conditioning_embeddings,
+ )
+ except Exception as exc:
+ # Without compilation there is nothing to fall back to: re-raise.
+ if not compile_graph:
+ raise
+ logger.warning(
+ "Diffusion decoder compilation failed; falling back "
+ "to the eager path: %s",
+ exc,
+ )
+ compile_graph = False
+ # Record the fallback on ctx so the caller can read it back afterwards.
+ ctx.compile_graph = False
+ decoder_logits_without_sc, decoder_logits_with_sc = (
+ _make_diffusion_decoder_logits_fns(
+ ctx.model,
+ ctx.kv_cache,
+ ctx.mask_mapping,
+ compile_graph=False,
+ )
+ )
+ # Retry the same step with the non-compiled functions.
+ if self_conditioning_embeddings is None:
+ processed_logits = decoder_logits_without_sc(current_canvas)
+ else:
+ processed_logits = decoder_logits_with_sc(
+ current_canvas,
+ self_conditioning_embeddings,
+ )
+ schedule_temperature = _diffusion_linear_temperature(
+ cur_step,
+ ctx.max_denoising_steps,
+ ctx.temperature_config,
+ )
+ if schedule_temperature is not None:
+ processed_logits = processed_logits / schedule_temperature
+
+ argmax_canvas = mx.argmax(processed_logits, axis=-1).astype(ctx.input_dtype)
+ # On the last step the result only needs argmax_canvas, so the sampling
+ # and draft work below is skipped unless a draft must still be shown.
+ if cur_step == 1 and not ctx.show_unmasking:
+ break
+
+ # Greedy proposal when temperature <= 0, sampled proposal otherwise.
+ denoiser_canvas = (
+ argmax_canvas
+ if ctx.temperature <= 0
+ else _diffusion_sample_canvas(
+ processed_logits,
+ ctx.input_dtype,
+ ctx.temperature,
+ )
+ )
+
+ if ctx.sampler == "entropy-bound":
+ # Before the last step, entropy and the next self-conditioning
+ # embeddings come from one helper call; on the last step only the
+ # entropy is computed.
+ if cur_step > 1:
+ token_entropy, next_self_conditioning_embeddings = (
+ _diffusion_entropy_and_soft_embeddings(
+ processed_logits,
+ ctx.soft_embedding_weight,
+ ctx.model.model.decoder.embed_scale,
+ )
+ )
+ else:
+ token_entropy = _diffusion_token_entropy(processed_logits)
+ next_self_conditioning_embeddings = None
+ acceptance_mask = _diffusion_entropy_transfer_mask(
+ token_entropy,
+ ctx.entropy_bound,
+ )
+ accepted_canvas = mx.where(
+ acceptance_mask,
+ denoiser_canvas,
+ current_canvas,
+ )
+ # Positions that were not accepted are re-initialized for the next step.
+ current_canvas = mx.where(
+ acceptance_mask,
+ accepted_canvas,
+ _diffusion_initialize_canvas(
+ ctx.batch_size,
+ ctx.canvas_length,
+ ctx.vocab_size,
+ ctx.input_dtype,
+ ),
+ )
+ # Here the reveal mask is replaced each step rather than accumulated.
+ draft_reveal_mask = acceptance_mask
+ draft_canvas = argmax_canvas
+ else:
+ # Any other sampler value takes the confidence-based path, where
+ # revealed positions accumulate across steps.
+ next_self_conditioning_embeddings = None
+ unrevealed_mask = ~draft_reveal_mask
+ confidence = _diffusion_token_probability(
+ processed_logits,
+ denoiser_canvas,
+ )
+ acceptance_mask = _diffusion_confidence_transfer_mask(
+ confidence,
+ unrevealed_mask,
+ ctx.threshold,
+ force_all=cur_step == 1,
+ )
+ accepted_canvas = mx.where(
+ acceptance_mask,
+ denoiser_canvas,
+ draft_canvas,
+ )
+ # Already-revealed and newly accepted positions keep their token;
+ # every other position is re-initialized.
+ current_canvas = mx.where(
+ draft_reveal_mask | acceptance_mask,
+ accepted_canvas,
+ _diffusion_initialize_canvas(
+ ctx.batch_size,
+ ctx.canvas_length,
+ ctx.vocab_size,
+ ctx.input_dtype,
+ ),
+ )
+ draft_reveal_mask = draft_reveal_mask | acceptance_mask
+ draft_canvas = mx.where(acceptance_mask, accepted_canvas, draft_canvas)
+
+ # Convert the countdown into a 1-based step number for display.
+ displayed_step = ctx.max_denoising_steps - cur_step + 1
+ # Drafts are shown on the first step, the last step and every
+ # unmasking_interval-th step.
+ should_show_unmasking = ctx.show_unmasking and (
+ displayed_step == 1
+ or cur_step == 1
+ or displayed_step % ctx.unmasking_interval == 0
+ )
+ if should_show_unmasking:
+ mx.eval(draft_canvas, draft_reveal_mask)
+ draft_text = _decode_diffusion_masked_draft(
+ ctx.tokenizer,
+ [int(token_id) for token_id in draft_canvas[0].tolist()],
+ [bool(v) for v in draft_reveal_mask[0].tolist()],
+ ctx.skip_special_token_ids,
+ max_chars=ctx.unmasking_width,
+ )
+ yield DiffusionCanvasDraft(
+ text=draft_text,
+ step=displayed_step,
+ total_steps=ctx.max_denoising_steps,
+ )
+
+ # Confidence-threshold sampling stops once every position is revealed.
+ # NOTE(review): accepted_canvas is assigned here but the result below is
+ # built from argmax_canvas — confirm that is intended.
+ if ctx.sampler == "confidence-threshold" and bool(
+ mx.all(draft_reveal_mask).item()
+ ):
+ accepted_canvas = draft_canvas
+ break
+
+ # Early exit decided by _diffusion_stable_and_confident, which also
+ # receives the diffusion_history list.
+ if _diffusion_stable_and_confident(
+ argmax_canvas,
+ processed_logits,
+ diffusion_history,
+ ctx.stopping_config,
+ ):
+ break
+
+ # Prepare self-conditioning for the next step, computing the soft
+ # embeddings here if the sampler branch did not already provide them.
+ if cur_step > 1:
+ if next_self_conditioning_embeddings is None:
+ next_self_conditioning_embeddings = _diffusion_soft_embeddings(
+ processed_logits,
+ ctx.soft_embedding_weight,
+ ctx.model.model.decoder.embed_scale,
+ )
+ self_conditioning_embeddings = next_self_conditioning_embeddings
+
+ yield DiffusionCanvasResult(
+ canvas=argmax_canvas,
+ canvas_tokens=ctx.canvas_length,
+ denoising_steps=denoising_steps,
+ work_tokens=ctx.canvas_length * denoising_steps,
+ )
+
+
def stream_diffusion_generate(
model: nn.Module,
processor: PreTrainedTokenizer,
@@ -551,6 +806,8 @@ def stream_diffusion_generate(
diffusion_unmasking_width: int = DEFAULT_DIFFUSION_UNMASKING_WIDTH,
mm_token_type_ids: Optional[mx.array] = None,
prefill_step_size: Optional[int] = None,
+ diffusion_canvas_denoiser: Optional[Any] = None,
+ diffusion_repeat_guard: Optional[int] = None,
) -> Generator[GenerationResult, None, None]:
if input_ids.shape[0] != 1:
raise ValueError(
@@ -604,6 +861,9 @@ def stream_diffusion_generate(
prefill_step_size = int(prefill_step_size)
if prefill_step_size <= 0:
raise ValueError("prefill_step_size must be a positive integer.")
+ repeat_guard = int(diffusion_repeat_guard or 0)
+ if repeat_guard < 0:
+ raise ValueError("diffusion_repeat_guard must be non-negative.")
sampler_config = _diffusion_config_dict(
generation_config.get("sampler_config", None)
@@ -690,6 +950,9 @@ def stream_diffusion_generate(
current_canvas = None
stopped = False
stop_reason = "length"
+ prev_emit_id = None
+ repeat_run = 0
+ canvas_denoiser = diffusion_canvas_denoiser or _standard_diffusion_canvas_denoiser
def make_result(
text: str,
@@ -778,227 +1041,62 @@ def stream_diffusion_generate(
if decoder_attention_mask is not None
else None
)
- current_canvas = _diffusion_initialize_canvas(
- batch_size,
- canvas_length,
- vocab_size,
- input_ids.dtype,
- )
- draft_reveal_mask = mx.zeros(current_canvas.shape, dtype=mx.bool_)
- draft_canvas = current_canvas
- accepted_canvas = current_canvas
- argmax_canvas = current_canvas
- self_conditioning_embeddings = None
mask_mapping = model.model.decoder._make_decoder_masks(
- current_canvas[..., None],
+ mx.zeros((batch_size, canvas_length, 1), dtype=input_ids.dtype),
kv_cache,
current_decoder_attention_mask,
)
- decoder_logits_without_sc, decoder_logits_with_sc = (
- _make_diffusion_decoder_logits_fns(
- model,
- kv_cache,
- mask_mapping,
- compile_graph=diffusion_compile,
- )
- )
- diffusion_history: List[mx.array] = []
- denoising_steps_this_canvas = 0
-
- if diffusion_show_unmasking:
- draft_text = _decode_diffusion_masked_draft(
- tokenizer,
- [int(token_id) for token_id in draft_canvas[0].tolist()],
- [False] * canvas_length,
- skip_special_token_ids,
- max_chars=diffusion_unmasking_width,
- )
- yield make_result(
- "",
- is_draft=True,
- draft_text=draft_text,
- diffusion_step=0,
- diffusion_total_steps=max_denoising_steps,
- diffusion_canvas_index=canvas_index,
- )
-
- for cur_step in reversed(range(1, max_denoising_steps + 1)):
- denoising_steps_this_canvas += 1
- try:
- if self_conditioning_embeddings is None:
- processed_logits = decoder_logits_without_sc(current_canvas)
- else:
- processed_logits = decoder_logits_with_sc(
- current_canvas,
- self_conditioning_embeddings,
- )
- except Exception as exc:
- if not diffusion_compile:
- raise
- logger.warning(
- "Diffusion decoder compilation failed; falling back "
- "to the eager path: %s",
- exc,
- )
- diffusion_compile = False
- decoder_logits_without_sc, decoder_logits_with_sc = (
- _make_diffusion_decoder_logits_fns(
- model,
- kv_cache,
- mask_mapping,
- compile_graph=False,
- )
- )
- if self_conditioning_embeddings is None:
- processed_logits = decoder_logits_without_sc(current_canvas)
- else:
- processed_logits = decoder_logits_with_sc(
- current_canvas,
- self_conditioning_embeddings,
- )
- schedule_temperature = _diffusion_linear_temperature(
- cur_step,
- max_denoising_steps,
- temperature_config,
- )
- if schedule_temperature is not None:
- processed_logits = processed_logits / schedule_temperature
-
- argmax_canvas = mx.argmax(processed_logits, axis=-1).astype(
- input_ids.dtype
- )
- if cur_step == 1 and not diffusion_show_unmasking:
- break
-
- denoiser_canvas = (
- argmax_canvas
- if temperature <= 0
- else _diffusion_sample_canvas(
- processed_logits,
- input_ids.dtype,
- temperature,
- )
- )
- if diffusion_sampler == "entropy-bound":
- if cur_step > 1:
- token_entropy, next_self_conditioning_embeddings = (
- _diffusion_entropy_and_soft_embeddings(
- processed_logits,
- soft_embedding_weight,
- model.model.decoder.embed_scale,
- )
- )
- else:
- token_entropy = _diffusion_token_entropy(processed_logits)
- next_self_conditioning_embeddings = None
- acceptance_mask = _diffusion_entropy_transfer_mask(
- token_entropy,
- entropy_bound,
- )
- accepted_canvas = mx.where(
- acceptance_mask,
- denoiser_canvas,
- current_canvas,
- )
- current_canvas = mx.where(
- acceptance_mask,
- accepted_canvas,
- _diffusion_initialize_canvas(
- batch_size,
- canvas_length,
- vocab_size,
- input_ids.dtype,
- ),
- )
- draft_reveal_mask = acceptance_mask
- draft_canvas = argmax_canvas
- else:
- next_self_conditioning_embeddings = None
- unrevealed_mask = ~draft_reveal_mask
- confidence = _diffusion_token_probability(
- processed_logits,
- denoiser_canvas,
- )
- acceptance_mask = _diffusion_confidence_transfer_mask(
- confidence,
- unrevealed_mask,
- diffusion_threshold,
- force_all=cur_step == 1,
- )
- accepted_canvas = mx.where(
- acceptance_mask,
- denoiser_canvas,
- draft_canvas,
- )
- current_canvas = mx.where(
- draft_reveal_mask | acceptance_mask,
- accepted_canvas,
- _diffusion_initialize_canvas(
- batch_size,
- canvas_length,
- vocab_size,
- input_ids.dtype,
- ),
- )
- draft_reveal_mask = draft_reveal_mask | acceptance_mask
- draft_canvas = mx.where(
- acceptance_mask, accepted_canvas, draft_canvas
- )
+ ctx = DiffusionCanvasDenoiseContext(
+ model=model,
+ tokenizer=tokenizer,
+ input_dtype=input_ids.dtype,
+ batch_size=batch_size,
+ canvas_length=canvas_length,
+ vocab_size=vocab_size,
+ kv_cache=kv_cache,
+ decoder_attention_mask=current_decoder_attention_mask,
+ mask_mapping=mask_mapping,
+ max_denoising_steps=max_denoising_steps,
+ temperature=temperature,
+ temperature_config=temperature_config,
+ sampler=diffusion_sampler,
+ threshold=diffusion_threshold,
+ entropy_bound=entropy_bound,
+ stopping_config=diffusion_stopping_config,
+ soft_embedding_weight=soft_embedding_weight,
+ compile_graph=diffusion_compile,
+ show_unmasking=diffusion_show_unmasking,
+ unmasking_interval=diffusion_unmasking_interval,
+ unmasking_width=diffusion_unmasking_width,
+ skip_special_token_ids=skip_special_token_ids,
+ )
- displayed_step = max_denoising_steps - cur_step + 1
- should_show_unmasking = diffusion_show_unmasking and (
- displayed_step == 1
- or cur_step == 1
- or displayed_step % diffusion_unmasking_interval == 0
- )
- if should_show_unmasking:
- mx.eval(draft_canvas, draft_reveal_mask)
- draft_text = _decode_diffusion_masked_draft(
- tokenizer,
- [int(token_id) for token_id in draft_canvas[0].tolist()],
- [bool(v) for v in draft_reveal_mask[0].tolist()],
- skip_special_token_ids,
- max_chars=diffusion_unmasking_width,
- )
+ canvas_result = None
+ for event in canvas_denoiser(ctx):
+ if isinstance(event, DiffusionCanvasDraft):
yield make_result(
"",
is_draft=True,
- draft_text=draft_text,
- diffusion_step=displayed_step,
- diffusion_total_steps=max_denoising_steps,
+ draft_text=event.text,
+ diffusion_step=event.step,
+ diffusion_total_steps=event.total_steps,
diffusion_canvas_index=canvas_index,
)
+ continue
+ canvas_result = event
- if diffusion_sampler == "confidence-threshold" and bool(
- mx.all(draft_reveal_mask).item()
- ):
- accepted_canvas = draft_canvas
- break
-
- if _diffusion_stable_and_confident(
- argmax_canvas,
- processed_logits,
- diffusion_history,
- diffusion_stopping_config,
- ):
- break
+ diffusion_compile = ctx.compile_graph
+ if canvas_result is None:
+ raise RuntimeError("Diffusion canvas denoiser did not return a canvas.")
- if cur_step > 1:
- if next_self_conditioning_embeddings is None:
- next_self_conditioning_embeddings = _diffusion_soft_embeddings(
- processed_logits,
- soft_embedding_weight,
- model.model.decoder.embed_scale,
- )
- self_conditioning_embeddings = next_self_conditioning_embeddings
-
- current_canvas = argmax_canvas
- diffusion_canvas_tokens += canvas_length
- diffusion_denoising_steps += denoising_steps_this_canvas
- diffusion_work_tokens += canvas_length * denoising_steps_this_canvas
+ current_canvas = canvas_result.canvas
+ diffusion_canvas_tokens += canvas_result.canvas_tokens
+ diffusion_denoising_steps += canvas_result.denoising_steps
+ diffusion_work_tokens += canvas_result.work_tokens
mx.eval(current_canvas)
+ committed_length = current_canvas.shape[-1]
for token_id in current_canvas[0].tolist():
last_token = int(token_id)
generated_tokens += 1
@@ -1008,6 +1106,14 @@ def stream_diffusion_generate(
stop_reason = "stop"
break
+ if repeat_guard:
+ repeat_run = repeat_run + 1 if last_token == prev_emit_id else 1
+ prev_emit_id = last_token
+ if repeat_run >= repeat_guard:
+ stopped = True
+ stop_reason = "repetition"
+ break
+
detokenizer.add_token(
last_token, skip_special_token_ids=skip_special_token_ids
)
@@ -1034,14 +1140,14 @@ def stream_diffusion_generate(
if use_static_cache:
decoder_attention_mask[
- :, cached_sequence_length : cached_sequence_length + canvas_length
+ :, cached_sequence_length : cached_sequence_length + committed_length
] = True
- cached_sequence_length += canvas_length
+ cached_sequence_length += committed_length
elif decoder_attention_mask is not None:
decoder_attention_mask = mx.concatenate(
[
decoder_attention_mask,
- mx.ones((batch_size, canvas_length), dtype=mx.bool_),
+ mx.ones((batch_size, committed_length), dtype=mx.bool_),
],
axis=-1,
)
@@ -1077,6 +1183,15 @@ def stream_diffusion_generate_from_kwargs(
diffusion_unmasking_width = kwargs.pop(
"diffusion_unmasking_width", DEFAULT_DIFFUSION_UNMASKING_WIDTH
)
+ diffusion_turbo = kwargs.pop("diffusion_turbo", False)
+ diffusion_canvas_denoiser = None
+ diffusion_repeat_guard = None
+ if diffusion_turbo:
+ from .diffusion_turbo import make_diffusion_turbo_denoiser_from_kwargs
+
+ diffusion_canvas_denoiser, diffusion_repeat_guard = (
+ make_diffusion_turbo_denoiser_from_kwargs(kwargs)
+ )
mm_token_type_ids = kwargs.pop("mm_token_type_ids", None)
prefill_step_size = kwargs.pop("prefill_step_size", None)
with wired_limit(model, [generation_stream]):
@@ -1103,5 +1218,7 @@ def stream_diffusion_generate_from_kwargs(
diffusion_unmasking_width=diffusion_unmasking_width,
mm_token_type_ids=mm_token_type_ids,
prefill_step_size=prefill_step_size,
+ diffusion_canvas_denoiser=diffusion_canvas_denoiser,
+ diffusion_repeat_guard=diffusion_repeat_guard,
)
mx.clear_cache()
diff --git a/mlx_vlm/generate/diffusion_turbo.py b/mlx_vlm/generate/diffusion_turbo.py
new file mode 100644
index 0000000..a31ab1a
--- /dev/null
+++ b/mlx_vlm/generate/diffusion_turbo.py
@@ -0,0 +1,744 @@
+"""Turbo per-canvas denoisers for block-diffusion generation.
+
+The shared stream lifecycle stays in ``diffusion.py``: prompt prefill, decoder
+mask construction, cache updates, token emission, draft emission, and result
+accounting. This module only supplies optional denoising policies that trade
+fidelity for throughput.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Generator, Optional
+
+import mlx.core as mx
+
+from .diffusion import (
+ DiffusionCanvasDenoiseContext,
+ DiffusionCanvasDraft,
+ DiffusionCanvasResult,
+ _decode_diffusion_masked_draft,
+ _diffusion_initialize_canvas,
+)
+
+DEFAULT_TURBO_TOPK = 64
+_SHAPE_BUCKETS = (16, 32, 48, 64, 96, 128, 160, 192, 224, 256)
+
+
+# Options read by TurboDiffusionDenoiser; populated from ``turbo_*`` kwargs by
+# make_diffusion_turbo_denoiser_from_kwargs.
+@dataclass
+class TurboDiffusionConfig:
+ # Number of logits kept per position by _topk_logits.
+ topk: int = DEFAULT_TURBO_TOPK
+ # When True, accepted positions are frozen and, after the first step, only
+ # the still-unfrozen positions are scored.
+ monotone: bool = True
+ # In the full-canvas monotone path: stop once a frozen stopping token is
+ # found and every position up to it is frozen.
+ eos_early_stop: bool = True
+ # Fraction of scored positions forced through the entropy mask each step
+ # (converted to a count of at least 1 when > 0); 0 disables it.
+ quota: float = 0.0
+ # From step index steps - 1 onwards every scored position is accepted;
+ # also reported as total_steps in drafts. None disables the flush.
+ steps: Optional[int] = None
+ # Acceptance rule: "entropy-bound" or "confidence".
+ accept: str = "entropy-bound"
+ # Confidence threshold: a single number, or a (hi, lo, switch_step) sequence.
+ threshold: Any = 0.9
+ # In confidence mode, limits acceptance to positions below a frontier;
+ # values <= 0 disable the limit.
+ window: int = 0
+ # Run one extra decoder pass over the final canvas when it was not truncated.
+ repair: bool = False
+ # Use the _compact denoising path instead of _full_canvas.
+ compact: bool = False
+ # Returned to the caller by make_diffusion_turbo_denoiser_from_kwargs and
+ # passed on as diffusion_repeat_guard.
+ repeat_guard: int = 16
+ # Overrides ctx.entropy_bound when not None.
+ entropy_bound: Optional[float] = None
+
+
+def make_diffusion_turbo_denoiser_from_kwargs(kwargs: dict):
+    """Build a turbo denoiser from ``turbo_*`` generation kwargs.
+
+    Every recognised option is popped from ``kwargs`` (including the plain
+    ``entropy_bound`` key), so the dict is mutated in place.
+
+    Returns:
+        A ``(TurboDiffusionDenoiser, repeat_guard)`` tuple.
+
+    Raises:
+        ValueError: for a non-positive ``turbo_topk`` or ``turbo_steps``, an
+            unknown ``turbo_accept``, or a negative ``turbo_repeat_guard``.
+    """
+    # Pop and coerce in a fixed order; ``turbo_monotone`` defaults to the
+    # inverse of ``turbo_no_monotone``.
+    disable_monotone = bool(kwargs.pop("turbo_no_monotone", False))
+    topk = int(kwargs.pop("turbo_topk", DEFAULT_TURBO_TOPK))
+    monotone = bool(kwargs.pop("turbo_monotone", not disable_monotone))
+    eos_early_stop = bool(kwargs.pop("turbo_eos_early_stop", True))
+    quota = float(kwargs.pop("turbo_quota", 0.0))
+    steps = kwargs.pop("turbo_steps", None)
+    accept = kwargs.pop("turbo_accept", "entropy-bound")
+    threshold = kwargs.pop("turbo_threshold", 0.9)
+    window = int(kwargs.pop("turbo_window", 0))
+    repair = bool(kwargs.pop("turbo_repair", False))
+    compact = bool(kwargs.pop("turbo_compact", False))
+    repeat_guard = int(kwargs.pop("turbo_repeat_guard", 16) or 0)
+    entropy_bound = kwargs.pop("entropy_bound", None)
+
+    if topk <= 0:
+        raise ValueError("turbo_topk must be a positive integer.")
+    if steps is not None:
+        steps = int(steps)
+        if steps <= 0:
+            raise ValueError("turbo_steps must be a positive integer.")
+    if accept not in ("entropy-bound", "confidence"):
+        raise ValueError(f"Unsupported turbo_accept: {accept!r}.")
+    if repeat_guard < 0:
+        raise ValueError("turbo_repeat_guard must be non-negative.")
+    if entropy_bound is not None:
+        entropy_bound = float(entropy_bound)
+
+    config = TurboDiffusionConfig(
+        topk=topk,
+        monotone=monotone,
+        eos_early_stop=eos_early_stop,
+        quota=quota,
+        steps=steps,
+        accept=accept,
+        threshold=threshold,
+        window=window,
+        repair=repair,
+        compact=compact,
+        repeat_guard=repeat_guard,
+        entropy_bound=entropy_bound,
+    )
+    return TurboDiffusionDenoiser(config), config.repeat_guard
+
+
+def _bucket_size(n: int) -> int:
+    """Return the first entry of ``_SHAPE_BUCKETS`` that is at least ``n``.
+
+    Values larger than every bucket are clamped to the last bucket.
+    """
+    fallback = _SHAPE_BUCKETS[-1]
+    return next((size for size in _SHAPE_BUCKETS if n <= size), fallback)
+
+
+def _turbo_temperature(ctx: DiffusionCanvasDenoiseContext, step: int) -> mx.array:
+    """Return the schedule temperature for a 0-based ``step`` as a float32 scalar.
+
+    Interpolates linearly from ``t_max`` at step 0 to ``t_min`` at step
+    ``ctx.max_denoising_steps - 1``; the bounds come from
+    ``ctx.temperature_config`` (defaults 0.8 and 0.4).
+    """
+    schedule = ctx.temperature_config
+    low = float(schedule.get("t_min", 0.4))
+    high = float(schedule.get("t_max", 0.8))
+    # Guard against a zero denominator when there is only one step.
+    span = max(ctx.max_denoising_steps - 1, 1)
+    remaining = 1.0 - step / span
+    return mx.array(low + (high - low) * remaining, dtype=mx.float32)
+
+
+def _topk_logits(raw_logits: mx.array, k: int, chunk: int = 64):
+ """Select k logits per position together with their vocabulary indices.
+
+ Returns (vals, idx), each with a trailing axis of size k. Selection goes
+ through mx.argpartition and no sort is applied afterwards.
+ """
+ *lead, vocab_size = raw_logits.shape
+ # k can never exceed the vocabulary size.
+ k = min(int(k), vocab_size)
+ # Direct path: taken when k covers the whole vocabulary, the vocabulary does
+ # not split evenly into chunks, or there are fewer than k chunks.
+ if k >= vocab_size or vocab_size % chunk != 0 or k > (vocab_size // chunk):
+ idx = mx.argpartition(raw_logits, kth=-k, axis=-1)[..., -k:]
+ vals = mx.take_along_axis(raw_logits, idx, axis=-1)
+ return vals, idx
+
+ # Two-stage path: keep the k chunks with the largest per-chunk maximum, then
+ # pick k entries among those k * chunk candidates.
+ num_chunks = vocab_size // chunk
+ chunked = raw_logits.reshape(*lead, num_chunks, chunk)
+ chunk_max = chunked.max(axis=-1)
+ chunk_idx = mx.argpartition(chunk_max, kth=-k, axis=-1)[..., -k:]
+ sub = mx.take_along_axis(chunked, mx.expand_dims(chunk_idx, -1), axis=-2)
+ sub = sub.reshape(*lead, k * chunk)
+ sub_idx = mx.argpartition(sub, kth=-k, axis=-1)[..., -k:]
+ vals = mx.take_along_axis(sub, sub_idx, axis=-1)
+ # sub_idx indexes the flattened (k, chunk) candidates: // chunk says which
+ # kept chunk an entry came from, % chunk is its offset inside that chunk.
+ selected_chunk = mx.take_along_axis(chunk_idx, sub_idx // chunk, axis=-1)
+ idx = selected_chunk * chunk + (sub_idx % chunk)
+ return vals, idx
+
+
+def _topk_postprocess(vals: mx.array, softcap: mx.array, temp: mx.array):
+    """Turn selected logits into probabilities and a per-position entropy.
+
+    The logits are cast to float32, soft-capped with ``tanh(x / softcap) *
+    softcap``, divided by ``temp`` and normalised over the last axis.
+
+    Returns:
+        ``(probs, entropy)`` where ``entropy`` drops the last axis.
+    """
+    as_float = vals.astype(mx.float32)
+    soft_capped = mx.tanh(as_float / softcap) * softcap
+    tempered = soft_capped / temp
+    normaliser = mx.logsumexp(tempered, axis=-1, keepdims=True)
+    log_p = tempered - normaliser
+    p = mx.exp(log_p)
+    return p, -mx.sum(p * log_p, axis=-1)
+
+
+def _entropy_transfer_mask(
+    entropy: mx.array,
+    entropy_bound: mx.array,
+    quota: int = 0,
+) -> mx.array:
+    """Build the acceptance mask for the entropy-bound rule.
+
+    Positions are ranked in ``mx.argsort`` order of their entropy. A rank is
+    accepted while the running entropy sum minus the running maximum stays
+    within ``entropy_bound``; when ``quota`` is positive the first ``quota``
+    ranks are accepted as well. The mask is returned in the original position
+    order.
+    """
+    order = mx.argsort(entropy, axis=-1)
+    ranked = mx.take_along_axis(entropy, order, axis=-1)
+    running_sum = mx.cumsum(ranked, axis=-1)
+    running_max = mx.cummax(ranked, axis=-1)
+    ranked_accept = (running_sum - running_max) <= entropy_bound
+    if quota > 0:
+        ranks = mx.arange(entropy.shape[-1])[None, :]
+        ranked_accept = ranked_accept | (ranks < quota)
+    # Scatter the per-rank decisions back to their original positions.
+    blank = mx.zeros_like(ranked_accept)
+    return mx.put_along_axis(blank, order, ranked_accept, axis=-1)
+
+
+class _TurboSelfConditioner:
+    """Builds self-conditioning embeddings from top-k probabilities."""
+
+    def __init__(self, decoder):
+        """Keep references to the decoder's embedding table and embed scale."""
+        self.embed_tokens = decoder.embed_tokens
+        self.embed_scale = decoder.embed_scale
+
+    def soft_embeddings(self, probs: mx.array, idx: mx.array) -> mx.array:
+        """Return the probability-weighted sum of the embeddings of ``idx``.
+
+        The weights are cast to the embedding dtype, the sum runs over the
+        second-to-last axis, and the result is multiplied by ``embed_scale``.
+        """
+        embedded = self.embed_tokens(idx)
+        weights = probs[..., None].astype(embedded.dtype)
+        blended = (weights * embedded).sum(axis=-2)
+        return blended * self.embed_scale
+
+
+class TurboDiffusionDenoiser:
+ def __init__(self, config: TurboDiffusionConfig):
+ """Keep the configuration that the denoising methods read."""
+ self.config = config
+
+ def __call__(
+     self,
+     ctx: DiffusionCanvasDenoiseContext,
+ ) -> Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]:
+     """Denoise one canvas, delegating to the compact or full-canvas path.
+
+     Raises:
+         ValueError: if ``ctx.batch_size`` is not 1.
+     """
+     if ctx.batch_size != 1:
+         raise ValueError("Turbo diffusion generation only supports batch size 1.")
+     denoise = self._compact if self.config.compact else self._full_canvas
+     yield from denoise(ctx)
+
+ @property
+ def _entropy_bound(self):
+ """Return config.entropy_bound (None when no override was configured)."""
+ return self.config.entropy_bound
+
+ def _softcap(self, ctx: DiffusionCanvasDenoiseContext) -> mx.array:
+     """Return the final-logit soft cap as a float32 scalar.
+
+     A truthy ``final_logit_softcapping`` on the model wins; otherwise the
+     value from ``model.config.text_config`` is used, defaulting to 30.0
+     when that attribute is absent.
+     """
+     text_config = ctx.model.config.text_config
+     cap = getattr(ctx.model, "final_logit_softcapping", None)
+     if not cap:
+         cap = getattr(text_config, "final_logit_softcapping", 30.0)
+     return mx.array(float(cap), dtype=mx.float32)
+
+ def _sample_from_topk(self, probs: mx.array, idx: mx.array, temperature: float):
+     """Pick one token per position from the top-k candidates.
+
+     With ``temperature <= 0`` the most probable candidate is taken; otherwise
+     a candidate is drawn with ``mx.random.categorical`` from the log
+     probabilities divided by the temperature (floored at 1e-5). The chosen
+     slot is mapped back through ``idx``.
+     """
+     if temperature <= 0:
+         slot = mx.argmax(probs, axis=-1)
+     else:
+         sampling_logits = mx.log(probs + 1e-20) / max(temperature, 1e-5)
+         slot = mx.random.categorical(sampling_logits)
+     chosen = mx.take_along_axis(idx, slot[..., None], axis=-1)
+     return chosen.squeeze(-1)
+
+ def _accept_mask(
+ self,
+ *,
+ probs: mx.array,
+ entropy: mx.array,
+ step: int,
+ entropy_bound: mx.array,
+ quota: int,
+ flush: bool,
+ active_positions=None,
+ canvas_length: Optional[int] = None,
+ ) -> mx.array:
+ """Build the boolean mask of scored positions accepted on this step.
+
+ flush accepts every position. With config.accept == "confidence" a
+ position passes when its largest top-k probability reaches the threshold,
+ optionally limited to a window of positions. Any other accept value
+ defers to _entropy_transfer_mask.
+ """
+ if flush:
+ return mx.ones(entropy.shape, dtype=mx.bool_)
+ if self.config.accept == "confidence":
+ # A (hi, lo, switch_step) threshold uses hi before switch_step and lo
+ # from then on; a scalar threshold is used as is.
+ if isinstance(self.config.threshold, (tuple, list)):
+ hi, lo, switch_step = self.config.threshold
+ threshold = hi if step < int(switch_step) else lo
+ else:
+ threshold = self.config.threshold
+ accept = probs.max(axis=-1) >= float(threshold)
+ if self.config.window > 0 and active_positions is not None:
+ # NOTE(review): canvas_length defaults to None and this branch then
+ # calls mx.arange(canvas_length) — confirm every caller passes it
+ # whenever window > 0.
+ if canvas_length is None or len(active_positions) == canvas_length:
+ frontier_limit = self.config.window
+ pos_of_active = mx.arange(canvas_length)[None, :]
+ else:
+ # The frontier starts at the first entry of active_positions.
+ frontier_limit = active_positions[0] + self.config.window
+ pos_of_active = mx.array(active_positions, dtype=mx.int32)[None, :]
+ accept = accept & (pos_of_active < frontier_limit)
+ # If nothing was accepted, accept the single position with the highest
+ # top probability. NOTE(review): that position is chosen over all scored
+ # positions, not only those inside the window — confirm intended.
+ if not bool(mx.any(accept).item()):
+ best = mx.argmax(probs.max(axis=-1), axis=-1)
+ accept = mx.arange(probs.shape[-2])[None, :] == best[:, None]
+ return accept
+ return _entropy_transfer_mask(entropy, entropy_bound, quota)
+
+ def _draft(
+     self,
+     ctx: DiffusionCanvasDenoiseContext,
+     token_ids,
+     reveal_mask,
+     step: int,
+ ) -> DiffusionCanvasDraft:
+     """Decode a masked draft and wrap it in a ``DiffusionCanvasDraft``.
+
+     ``token_ids`` and ``reveal_mask`` are coerced element-wise to ``int``
+     and ``bool``. ``total_steps`` is ``config.steps`` when that is truthy,
+     otherwise ``ctx.max_denoising_steps``.
+     """
+     draft_text = _decode_diffusion_masked_draft(
+         ctx.tokenizer,
+         list(map(int, token_ids)),
+         list(map(bool, reveal_mask)),
+         ctx.skip_special_token_ids,
+         max_chars=ctx.unmasking_width,
+     )
+     reported_total = self.config.steps or ctx.max_denoising_steps
+     return DiffusionCanvasDraft(
+         text=draft_text,
+         step=step,
+         total_steps=reported_total,
+     )
+
+ def _full_canvas(
+ self,
+ ctx: DiffusionCanvasDenoiseContext,
+ ) -> Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]:
+ """Denoise one canvas by running the decoder layers directly.
+
+ Each step embeds the canvas, applies self-conditioning, runs every decoder
+ layer and scores only the top-k logits. Yields drafts when
+ ctx.show_unmasking is set and finishes with one DiffusionCanvasResult
+ whose canvas is cut to emit_length.
+ """
+ decoder = ctx.model.model.decoder
+ self_conditioner = _TurboSelfConditioner(decoder)
+ softcap = self._softcap(ctx)
+ # A configured entropy bound takes precedence over ctx.entropy_bound.
+ entropy_bound = mx.array(
+ self._entropy_bound if self._entropy_bound is not None else ctx.entropy_bound,
+ dtype=mx.float32,
+ )
+ # Early-stop thresholds with local defaults when stopping_config lacks them.
+ stability_threshold = int(
+ (ctx.stopping_config or {}).get("stability_threshold", 1)
+ )
+ confidence_threshold = float(
+ (ctx.stopping_config or {}).get("confidence_threshold", 0.005)
+ )
+
+ canvas = _diffusion_initialize_canvas(
+ ctx.batch_size,
+ ctx.canvas_length,
+ ctx.vocab_size,
+ mx.int32,
+ )
+ # frozen marks positions whose token has been accepted (monotone mode);
+ # committed holds the accepted tokens.
+ frozen = mx.zeros((ctx.batch_size, ctx.canvas_length), dtype=mx.bool_)
+ committed = canvas
+ sc_embeddings = None
+ prev_argmax = None
+ stable_count = 0
+ final_canvas = None
+ emit_length = ctx.canvas_length
+ active_positions = list(range(ctx.canvas_length))
+ steps = 0
+ work_tokens = 0
+
+ # Step 0 draft: the initial canvas with nothing revealed.
+ if ctx.show_unmasking:
+ yield self._draft(
+ ctx,
+ canvas[0].tolist(),
+ [False] * ctx.canvas_length,
+ 0,
+ )
+
+ # NOTE(review): the loop always spans ctx.max_denoising_steps; flush only
+ # forces acceptance and the loop is left solely through the monotone, EOS
+ # and stability breaks below — confirm non-monotone runs with config.steps
+ # are meant to keep iterating.
+ for step in range(ctx.max_denoising_steps):
+ steps += 1
+ # Manual decoder forward pass: embed, mix in self-conditioning (zeros on
+ # the first step), run every layer, apply the final norm.
+ inputs_embeds = decoder.embed_tokens(canvas) * decoder.embed_scale
+ soft = (
+ sc_embeddings.astype(inputs_embeds.dtype)
+ if sc_embeddings is not None
+ else mx.zeros_like(inputs_embeds)
+ )
+ h = decoder.self_conditioning(inputs_embeds, soft)
+ cache_list = ctx.kv_cache or [None] * len(decoder.layers)
+ # Use the first cache's offset only when that cache has a non-None keys
+ # attribute; otherwise start from 0.
+ offset = (
+ cache_list[0].offset
+ if cache_list and getattr(cache_list[0], "keys", None) is not None
+ else 0
+ )
+ for layer, cache in zip(decoder.layers, cache_list):
+ h = layer(
+ h,
+ ctx.mask_mapping.get(layer.layer_type),
+ cache,
+ decoder=True,
+ offset=offset,
+ )
+ h = decoder.norm(h)
+
+ # In monotone mode, after the first step only unfrozen positions are scored.
+ if self.config.monotone and step > 0:
+ active_idx = mx.array(active_positions, dtype=mx.int32)
+ h_active = h[:, active_idx, :]
+ else:
+ active_idx = None
+ h_active = h
+ # work_tokens counts the positions actually scored on each step.
+ work_tokens += int(h_active.shape[1])
+
+ raw_logits = decoder.embed_tokens.as_linear(h_active)
+ vals, idx = _topk_logits(raw_logits, self.config.topk)
+ probs, entropy = _topk_postprocess(
+ vals,
+ softcap,
+ _turbo_temperature(ctx, step),
+ )
+ # proposal is sampled (or greedy when ctx.temperature <= 0); top1 is
+ # always the most probable candidate.
+ proposal = self._sample_from_topk(probs, idx, ctx.temperature).astype(
+ mx.int32
+ )
+ top1 = (
+ mx.take_along_axis(
+ idx,
+ mx.argmax(probs, axis=-1)[..., None],
+ axis=-1,
+ )
+ .squeeze(-1)
+ .astype(mx.int32)
+ )
+
+ # Convert the quota fraction into a position count of at least 1.
+ quota = 0
+ if self.config.quota > 0:
+ quota = max(1, int(entropy.shape[-1] * self.config.quota))
+ # Once step reaches config.steps - 1, every scored position is accepted.
+ flush = self.config.steps is not None and step >= self.config.steps - 1
+ accept = self._accept_mask(
+ probs=probs,
+ entropy=entropy,
+ step=step,
+ entropy_bound=entropy_bound,
+ quota=quota,
+ flush=flush,
+ active_positions=active_positions,
+ canvas_length=ctx.canvas_length,
+ )
+
+ if self.config.monotone:
+ if active_idx is None:
+ # First step: all positions were scored, so update in place.
+ committed = mx.where(accept, proposal, committed)
+ canvas = mx.where(
+ accept,
+ committed,
+ mx.random.randint(0, ctx.vocab_size, canvas.shape),
+ )
+ frozen = frozen | accept
+ argmax_full = top1
+ else:
+ # Later steps: results cover only the active positions and are
+ # scattered back into the full-length arrays.
+ committed_active = mx.where(
+ accept,
+ proposal,
+ mx.take_along_axis(committed, active_idx[None, :], axis=-1),
+ )
+ committed = mx.put_along_axis(
+ committed,
+ active_idx[None, :],
+ committed_active,
+ axis=-1,
+ )
+ # Rejected active positions get fresh random tokens.
+ canvas_active = mx.where(
+ accept,
+ committed_active,
+ mx.random.randint(0, ctx.vocab_size, accept.shape),
+ )
+ canvas = mx.put_along_axis(
+ canvas,
+ active_idx[None, :],
+ canvas_active,
+ axis=-1,
+ )
+ frozen = mx.put_along_axis(
+ frozen,
+ active_idx[None, :],
+ mx.take_along_axis(frozen, active_idx[None, :], axis=-1)
+ | accept,
+ axis=-1,
+ )
+ argmax_full = mx.put_along_axis(
+ committed.astype(top1.dtype),
+ active_idx[None, :],
+ top1,
+ axis=-1,
+ )
+ else:
+ # Non-monotone: nothing is frozen; rejected positions are re-randomized.
+ committed = mx.where(accept, proposal, canvas)
+ canvas = mx.where(
+ accept,
+ committed,
+ mx.random.randint(0, ctx.vocab_size, canvas.shape),
+ )
+ argmax_full = top1
+
+ # Self-conditioning for the next step; with a subset of active positions
+ # the new embeddings are written into the previous full-length tensor.
+ sc_active = self_conditioner.soft_embeddings(probs, idx)
+ if self.config.monotone and active_idx is not None:
+ base = (
+ sc_embeddings
+ if sc_embeddings is not None
+ else mx.zeros(
+ (ctx.batch_size, ctx.canvas_length, sc_active.shape[-1]),
+ dtype=sc_active.dtype,
+ )
+ )
+ sc_embeddings = mx.put_along_axis(
+ base,
+ mx.broadcast_to(
+ active_idx[None, :, None],
+ (ctx.batch_size, active_idx.size, sc_active.shape[-1]),
+ ),
+ sc_active,
+ axis=1,
+ )
+ else:
+ sc_embeddings = sc_active
+
+ if self.config.monotone:
+ # Bring the masks to Python lists to rebuild the unfrozen positions.
+ mx.eval(frozen, committed, entropy)
+ frozen_list = frozen[0].tolist()
+ active_positions = [i for i, value in enumerate(frozen_list) if not value]
+ mean_entropy = float(mx.mean(entropy).item())
+
+ if self.config.eos_early_stop:
+ # Find the first frozen token that the tokenizer treats as a stop.
+ eos_pos = None
+ committed_list = committed[0].tolist()
+ for i, (token_id, is_frozen) in enumerate(
+ zip(committed_list, frozen_list)
+ ):
+ if is_frozen and ctx.tokenizer.stopping_criteria(int(token_id)):
+ eos_pos = i
+ break
+ # Stop only when everything up to that token is frozen; the
+ # emitted canvas is then cut right after it.
+ if eos_pos is not None and all(frozen_list[: eos_pos + 1]):
+ final_canvas = committed
+ emit_length = eos_pos + 1
+ break
+
+ # Every position is frozen: the canvas is finished.
+ if not active_positions:
+ final_canvas = committed
+ break
+ # NOTE(review): when nothing is frozen active_positions already equals
+ # the full range, so this reassignment leaves its contents unchanged.
+ if len(active_positions) == ctx.canvas_length:
+ active_positions = list(range(ctx.canvas_length))
+ else:
+ mx.eval(argmax_full, entropy)
+ mean_entropy = float(mx.mean(entropy).item())
+
+ # Stability-based early stop, non-monotone mode only: the argmax must stay
+ # unchanged for stability_threshold consecutive comparisons while the
+ # mean entropy is below confidence_threshold.
+ if prev_argmax is not None and not self.config.monotone:
+ if bool(mx.all(argmax_full == prev_argmax).item()):
+ stable_count += 1
+ else:
+ stable_count = 0
+ if (
+ stable_count >= stability_threshold
+ and mean_entropy < confidence_threshold
+ ):
+ final_canvas = argmax_full
+ break
+ prev_argmax = argmax_full
+
+ # Drafts are shown on the first step, the last step, every
+ # unmasking_interval-th step and whenever flush is active.
+ displayed_step = step + 1
+ should_show_unmasking = ctx.show_unmasking and (
+ displayed_step == 1
+ or displayed_step == ctx.max_denoising_steps
+ or displayed_step % ctx.unmasking_interval == 0
+ or flush
+ )
+ if should_show_unmasking:
+ if self.config.monotone:
+ yield self._draft(
+ ctx,
+ committed[0].tolist(),
+ frozen[0].tolist(),
+ displayed_step,
+ )
+ else:
+ # Without a frozen mask, positions are shown as revealed only on flush.
+ reveal = [True] * ctx.canvas_length if flush else [False] * ctx.canvas_length
+ yield self._draft(ctx, argmax_full[0].tolist(), reveal, displayed_step)
+
+ # The loop ended without an explicit early stop.
+ if final_canvas is None:
+ final_canvas = committed if self.config.monotone else argmax_full
+
+ # Optional repair: one more decoder pass over the full final canvas, keeping
+ # the candidate with the largest raw top-k logit per position. Skipped when
+ # the canvas was cut at an EOS token.
+ if self.config.repair and emit_length == ctx.canvas_length:
+ inputs_embeds = decoder.embed_tokens(final_canvas) * decoder.embed_scale
+ soft = (
+ sc_embeddings.astype(inputs_embeds.dtype)
+ if sc_embeddings is not None
+ else mx.zeros_like(inputs_embeds)
+ )
+ h = decoder.self_conditioning(inputs_embeds, soft)
+ cache_list = ctx.kv_cache or [None] * len(decoder.layers)
+ offset = (
+ cache_list[0].offset
+ if cache_list and getattr(cache_list[0], "keys", None) is not None
+ else 0
+ )
+ for layer, cache in zip(decoder.layers, cache_list):
+ h = layer(
+ h,
+ ctx.mask_mapping.get(layer.layer_type),
+ cache,
+ decoder=True,
+ offset=offset,
+ )
+ h = decoder.norm(h)
+ vals, idx = _topk_logits(decoder.embed_tokens.as_linear(h), self.config.topk)
+ final_canvas = (
+ mx.take_along_axis(idx, mx.argmax(vals, axis=-1)[..., None], axis=-1)
+ .squeeze(-1)
+ .astype(mx.int32)
+ )
+ # The repair pass counts as one extra step over the whole canvas.
+ steps += 1
+ work_tokens += ctx.canvas_length
+
+ # NOTE(review): canvas_tokens stays at the full canvas_length even when the
+ # emitted canvas is shorter — confirm the accounting is intended.
+ yield DiffusionCanvasResult(
+ canvas=final_canvas[:, :emit_length],
+ canvas_tokens=ctx.canvas_length,
+ denoising_steps=steps,
+ work_tokens=work_tokens,
+ )
+
+ def _compact(
+ self,
+ ctx: DiffusionCanvasDenoiseContext,
+ ) -> Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]:
+ """Denoise the canvas while forwarding only a subset of positions per step.
+
+ Each step runs the decoder on the positions that are not yet frozen plus
+ the positions frozen during the previous step, commits the proposals
+ that pass ``self._accept_mask``, and freezes them. Intermediate drafts
+ are yielded when ``ctx.show_unmasking`` is set; the last item yielded is
+ always a ``DiffusionCanvasResult``.
+
+ Args:
+ ctx: Denoising context (model, tokenizer, KV cache, canvas size,
+ step limits, and display options).
+
+ Yields:
+ Drafts built by ``self._draft`` and finally one
+ ``DiffusionCanvasResult`` whose canvas is cut to the emitted length.
+
+ Raises:
+ ValueError: If ``ctx.decoder_attention_mask`` is not ``None``.
+ """
+ if ctx.decoder_attention_mask is not None:
+ raise ValueError(
+ "turbo_compact does not support padded or static decoder masks yet."
+ )
+
+ # Imported inside the method rather than at module import time.
+ from ..models.diffusion_gemma.diffusion_turbo_runner import TurboCanvasRunner
+
+ decoder = ctx.model.model.decoder
+ self_conditioner = _TurboSelfConditioner(decoder)
+ # The prefix length is read from the first layer's cache only; a cache
+ # without stored keys counts as an empty prefix.
+ cache0 = ctx.kv_cache[0]
+ prefix_offset = (
+ int(cache0.offset) if getattr(cache0, "keys", None) is not None else 0
+ )
+ # The last argument is the number of RoPE table rows the runner builds.
+ # NOTE(review): the reason for the extra 8 rows is not visible here.
+ runner = TurboCanvasRunner(
+ ctx.model,
+ ctx.kv_cache,
+ ctx.canvas_length,
+ prefix_offset + ctx.canvas_length + 8,
+ )
+ softcap = self._softcap(ctx)
+ # An entropy bound set on this instance overrides the one in ctx.
+ entropy_bound = mx.array(
+ self._entropy_bound if self._entropy_bound is not None else ctx.entropy_bound,
+ dtype=mx.float32,
+ )
+ hidden_size = decoder.config.hidden_size
+ # The canvas starts as uniformly random token ids for a single sequence.
+ canvas_dev = mx.random.randint(0, ctx.vocab_size, (1, ctx.canvas_length)).astype(
+ mx.int32
+ )
+ # committed_dev holds the accepted tokens; canvas_dev is what the decoder
+ # sees and gets fresh noise at positions that were not accepted.
+ committed_dev = canvas_dev
+ # Soft (self-conditioning) embeddings for every canvas slot, zero at start.
+ sc_full = mx.zeros((1, ctx.canvas_length, hidden_size), dtype=mx.bfloat16)
+ # Host-side bookkeeping: frozen slots are never forwarded as active again;
+ # tail_dropped marks slots frozen only because they follow a stop token.
+ frozen = [False] * ctx.canvas_length
+ tail_dropped = [False] * ctx.canvas_length
+ newly_frozen: list[int] = []
+ steps = 0
+ work_tokens = 0
+ emit_length = ctx.canvas_length
+
+ # Step-0 draft of the initial random canvas with nothing revealed.
+ if ctx.show_unmasking:
+ yield self._draft(
+ ctx,
+ canvas_dev[0].tolist(),
+ [False] * ctx.canvas_length,
+ 0,
+ )
+
+ for step in range(ctx.max_denoising_steps):
+ active = [i for i in range(ctx.canvas_length) if not frozen[i]]
+ if not active:
+ break
+ # Positions frozen in the previous step are forwarded once more.
+ # Presumably this refreshes their keys/values in the runner with
+ # the committed token — TODO confirm.
+ forward_positions = sorted(active + newly_frozen)
+ steps += 1
+ work_tokens += len(forward_positions)
+
+ # Pad both position lists up to their bucket size by repeating the
+ # last position. Presumably this limits the number of distinct
+ # array shapes — confirm against _bucket_size.
+ forward_bucket = _bucket_size(len(forward_positions))
+ forward_bucketed = forward_positions + [forward_positions[-1]] * (
+ forward_bucket - len(forward_positions)
+ )
+ active_bucket = _bucket_size(len(active))
+ active_bucketed = active + [active[-1]] * (active_bucket - len(active))
+
+ rel = mx.array(forward_bucketed, dtype=mx.int32)
+ runner.set_forward_positions(rel)
+ # Tokens, absolute positions, and soft embeddings are gathered for
+ # the forwarded slots only.
+ h_forward = runner(
+ canvas_dev[:, rel],
+ prefix_offset + rel,
+ sc_full[:, rel, :],
+ )
+
+ # Map each active canvas position to its row in h_forward; padded
+ # entries repeat the row of the last active position.
+ pos_in_forward = {p: i for i, p in enumerate(forward_positions)}
+ active_idx = mx.array(
+ [pos_in_forward[p] for p in active_bucketed],
+ dtype=mx.int32,
+ )
+ raw_logits = decoder.embed_tokens.as_linear(h_forward[:, active_idx, :])
+ vals, idx = _topk_logits(raw_logits, self.config.topk)
+ probs, entropy = _topk_postprocess(
+ vals,
+ softcap,
+ _turbo_temperature(ctx, step),
+ )
+
+ # Drop the padding rows before sampling and acceptance.
+ real_count = len(active)
+ real_probs = probs[:, :real_count, :]
+ real_idx = idx[:, :real_count, :]
+ real_entropy = entropy[:, :real_count]
+ real_active_pos = mx.array(active, dtype=mx.int32)
+ # NOTE(review): probs use the per-step temperature above while
+ # sampling receives ctx.temperature — confirm this is intended.
+ proposal = self._sample_from_topk(
+ real_probs,
+ real_idx,
+ ctx.temperature,
+ ).astype(mx.int32)
+ # flush is set from step index config.steps - 1 onward when
+ # config.steps is configured.
+ flush = self.config.steps is not None and step >= self.config.steps - 1
+ accept = self._accept_mask(
+ probs=real_probs,
+ entropy=real_entropy,
+ step=step,
+ entropy_bound=entropy_bound,
+ quota=0,
+ flush=flush,
+ )
+
+ # Decoder input: accepted slots take the proposal, the rest are
+ # re-noised with fresh random tokens.
+ noise = mx.random.randint(0, ctx.vocab_size, accept.shape).astype(mx.int32)
+ canvas_dev = mx.put_along_axis(
+ canvas_dev,
+ real_active_pos[None, :],
+ mx.where(accept, proposal, noise),
+ axis=-1,
+ )
+ # Committed view: accepted slots take the proposal, the rest keep
+ # whatever they held before.
+ committed_dev = mx.put_along_axis(
+ committed_dev,
+ real_active_pos[None, :],
+ mx.where(
+ accept,
+ proposal,
+ mx.take_along_axis(
+ committed_dev,
+ real_active_pos[None, :],
+ axis=-1,
+ ),
+ ),
+ axis=-1,
+ )
+ # Store this step's soft embeddings for every active slot.
+ sc_active = self_conditioner.soft_embeddings(real_probs, real_idx)
+ sc_full = mx.put_along_axis(
+ sc_full,
+ mx.broadcast_to(
+ real_active_pos[None, :, None],
+ (1, real_count, hidden_size),
+ ),
+ sc_active.astype(sc_full.dtype),
+ axis=1,
+ )
+
+ mx.eval(canvas_dev, committed_dev, sc_full)
+ accept_list = accept[0].tolist()
+ # If nothing was accepted (and this is not a flush step), force the
+ # slot with the highest top probability to be accepted so that at
+ # least one position freezes per step.
+ if not any(accept_list) and not flush:
+ best = int(mx.argmax(real_probs.max(axis=-1), axis=-1).item())
+ accept_list[best] = True
+ p = active[best]
+ token_id = int(proposal[0, best].item())
+ pos = mx.array([[p]], dtype=mx.int32)
+ token = mx.array([[token_id]], dtype=mx.int32)
+ canvas_dev = mx.put_along_axis(canvas_dev, pos, token, axis=-1)
+ committed_dev = mx.put_along_axis(committed_dev, pos, token, axis=-1)
+
+ # Freeze everything accepted in this step and remember it for the
+ # next step's forward set.
+ newly_frozen = []
+ for i, accepted in enumerate(accept_list):
+ if accepted:
+ pos = active[i]
+ frozen[pos] = True
+ newly_frozen.append(pos)
+
+ if self.config.eos_early_stop and newly_frozen:
+ committed_now = committed_dev[0].tolist()
+ # Find the first frozen slot holding a token the tokenizer's
+ # stopping criteria accepts.
+ eos_pos = None
+ for i, token_id in enumerate(committed_now):
+ if frozen[i] and ctx.tokenizer.stopping_criteria(int(token_id)):
+ eos_pos = i
+ break
+ if eos_pos is not None:
+ # Everything up to the stop token is settled: emit that prefix.
+ if all(frozen[: eos_pos + 1]):
+ emit_length = eos_pos + 1
+ break
+ # Otherwise stop working on slots after the stop token: freeze
+ # them and flag them so drafts do not show them as revealed.
+ for i in range(eos_pos + 1, ctx.canvas_length):
+ if not frozen[i]:
+ frozen[i] = True
+ tail_dropped[i] = True
+ # Slots after the stop token are not forwarded next step.
+ newly_frozen = [p for p in newly_frozen if p <= eos_pos]
+
+ # Emit a draft on the first step, the last allowed step, every
+ # unmasking_interval-th step, and on flush steps.
+ displayed_step = step + 1
+ should_show_unmasking = ctx.show_unmasking and (
+ displayed_step == 1
+ or displayed_step == ctx.max_denoising_steps
+ or displayed_step % ctx.unmasking_interval == 0
+ or flush
+ )
+ if should_show_unmasking:
+ committed_view = committed_dev[0].tolist()
+ reveal = [
+ frozen[i] and not tail_dropped[i] for i in range(ctx.canvas_length)
+ ]
+ yield self._draft(ctx, committed_view, reveal, displayed_step)
+
+ committed = [int(token_id) for token_id in committed_dev[0].tolist()]
+
+ # Optional repair pass, skipped when the output was cut at a stop token:
+ # forward the whole committed canvas once and replace every token with
+ # the highest-scoring candidate from the top-k logits.
+ if self.config.repair and emit_length == ctx.canvas_length:
+ rel = mx.arange(ctx.canvas_length, dtype=mx.int32)
+ runner.set_forward_positions(rel)
+ h_forward = runner(
+ mx.array([committed], dtype=mx.int32),
+ prefix_offset + rel,
+ sc_full,
+ )
+ vals, idx = _topk_logits(
+ decoder.embed_tokens.as_linear(h_forward),
+ self.config.topk,
+ )
+ repaired = (
+ mx.take_along_axis(idx, mx.argmax(vals, axis=-1)[..., None], axis=-1)
+ .squeeze(-1)
+ .astype(mx.int32)
+ )
+ mx.eval(repaired)
+ committed = [int(token_id) for token_id in repaired[0].tolist()]
+ steps += 1
+ work_tokens += ctx.canvas_length
+
+ # canvas_tokens reports the full canvas size even when the emitted canvas
+ # is shorter.
+ yield DiffusionCanvasResult(
+ canvas=mx.array([committed[:emit_length]], dtype=mx.int32),
+ canvas_tokens=ctx.canvas_length,
+ denoising_steps=steps,
+ work_tokens=work_tokens,
+ )
diff --git a/mlx_vlm/models/diffusion_gemma/diffusion_turbo_runner.py b/mlx_vlm/models/diffusion_gemma/diffusion_turbo_runner.py
new file mode 100644
index 0000000..2439411
--- /dev/null
+++ b/mlx_vlm/models/diffusion_gemma/diffusion_turbo_runner.py
@@ -0,0 +1,238 @@
+"""Active-set-compacted decoder runner for DiffusionGemma turbo diffusion."""
+
+from __future__ import annotations
+
+from typing import List, Optional
+
+import mlx.core as mx
+
+from .language import _cache_offset, _cache_state
+
+
+def _rope_tables(dims: int, base: float, positions: int):
+    """Build cosine and sine rotary tables from a frequency base.
+
+    Args:
+        dims: Rotary dimension; one frequency is produced per pair of dims.
+        base: Base that is raised to the negated, normalised even indices.
+        positions: Number of consecutive positions (table rows), from 0.
+
+    Returns:
+        A ``(cos, sin)`` tuple, each with one row per position and one
+        column per frequency.
+    """
+    exponents = mx.arange(0, dims, 2, dtype=mx.float32) / dims
+    inverse_frequencies = base ** (-exponents)
+    position_column = mx.arange(positions, dtype=mx.float32)[:, None]
+    theta = position_column * inverse_frequencies[None, :]
+    return mx.cos(theta), mx.sin(theta)
+
+
+def _rope_tables_from_freqs(freqs: mx.array, positions: int):
+    """Build cosine and sine rotary tables from explicit frequency divisors.
+
+    Args:
+        freqs: One divisor per rotary pair; each position index is divided
+            by these values to obtain the rotation angle.
+        positions: Number of consecutive positions (table rows), from 0.
+
+    Returns:
+        A ``(cos, sin)`` tuple, each with one row per position and one
+        column per entry of ``freqs``.
+    """
+    position_column = mx.arange(positions, dtype=mx.float32)[:, None]
+    theta = position_column / freqs[None, :]
+    return mx.cos(theta), mx.sin(theta)
+
+
+def _apply_rope_gathered(x: mx.array, cos: mx.array, sin: mx.array):
+    """Rotate the two halves of ``x``'s last axis with pre-gathered tables.
+
+    The last axis is split into a first and a second half, which are rotated
+    against each other in float32. ``cos`` and ``sin`` are given two leading
+    singleton axes before being multiplied with the halves, and the result is
+    cast back to the dtype of ``x``.
+    """
+    split = x.shape[-1] // 2
+    first = x[..., :split].astype(mx.float32)
+    second = x[..., split:].astype(mx.float32)
+    cos_b = cos[None, None, :, :]
+    sin_b = sin[None, None, :, :]
+    rotated_first = first * cos_b - second * sin_b
+    rotated_second = second * cos_b + first * sin_b
+    rotated = mx.concatenate([rotated_first, rotated_second], axis=-1)
+    return rotated.astype(x.dtype)
+
+
+class TurboCanvasRunner:
+    """Forward DiffusionGemma decoder layers over a compacted canvas subset.
+
+    The runner owns one key buffer and one value buffer per decoder layer,
+    each with a slot for every canvas position. A call projects only the
+    positions it is given, scatters their keys/values into the buffers, and
+    attends over the stored prefix cache concatenated with the whole canvas
+    buffer. Slots that are not part of a call keep whatever an earlier call
+    wrote to them (zeros if none did).
+
+    ``set_forward_positions`` must be called with the canvas-relative
+    positions before each ``__call__`` that uses a different position set.
+    """
+
+    def __init__(self, model, kv_cache, canvas_length: int, max_position: int):
+        """Precompute RoPE tables, snapshot the prefix cache, allocate buffers.
+
+        Args:
+            model: Object exposing the decoder as ``model.model.decoder``.
+            kv_cache: Per-layer prefix caches, paired with the decoder layers
+                in order.
+            canvas_length: Number of canvas slots in every K/V buffer.
+            max_position: Number of rows in the RoPE tables; absolute
+                positions passed to ``__call__`` index into these rows.
+        """
+        self.decoder = model.model.decoder
+        self.config = self.decoder.config
+        self.layers = self.decoder.layers
+        self.kv_cache = kv_cache
+
+        cfg = self.config
+        sliding_params = cfg.rope_parameters.get("sliding_attention", {})
+        full_params = cfg.rope_parameters.get("full_attention", {})
+
+        # Sliding-attention layers rotate the full head dimension.
+        self.cos_s, self.sin_s = _rope_tables(
+            cfg.head_dim,
+            float(sliding_params.get("rope_theta", 10000.0)),
+            max_position,
+        )
+
+        # Full-attention layers rotate only ``full_rotated`` dims, while the
+        # exponents stay normalised by the whole global head dimension.
+        global_dim = cfg.global_head_dim or cfg.head_dim
+        partial = float(full_params.get("partial_rotary_factor", 1.0))
+        rope_angles = int(partial * global_dim // 2)
+        self.full_rotated = 2 * rope_angles
+        factor = float(full_params.get("factor", 1.0))
+        base = float(full_params.get("rope_theta", 10000.0))
+        exponents = mx.arange(0, self.full_rotated, 2, dtype=mx.float32) / global_dim
+        freqs = factor * (base**exponents)
+        self.cos_f, self.sin_f = _rope_tables_from_freqs(freqs, max_position)
+        mx.eval(self.cos_s, self.sin_s, self.cos_f, self.sin_f)
+
+        # Snapshot the prefix keys/values once; they are reused by every call.
+        self.prefix_k: List[Optional[mx.array]] = []
+        self.prefix_v: List[Optional[mx.array]] = []
+        window = max(cfg.sliding_window - 1, 0)
+        for layer, cache in zip(self.layers, kv_cache):
+            state = _cache_state(cache)
+            if state is None:
+                self.prefix_k.append(None)
+                self.prefix_v.append(None)
+                continue
+            keys, values = state
+            offset = _cache_offset(cache)
+            if layer.layer_type == "sliding_attention":
+                # Keep only the last ``sliding_window - 1`` prefix entries
+                # when the cache offset has reached the stored length.
+                encoder_len = keys.shape[2]
+                if window and encoder_len > window and offset >= encoder_len:
+                    keys = keys[:, :, -window:, :]
+                    values = values[:, :, -window:, :]
+            self.prefix_k.append(keys)
+            self.prefix_v.append(values)
+
+        # Buffers use the dtype of ``embed_tokens.scales`` when that attribute
+        # exists (presumably a quantized embedding), otherwise bfloat16.
+        dtype = (
+            self.decoder.embed_tokens.scales.dtype
+            if hasattr(self.decoder.embed_tokens, "scales")
+            else mx.bfloat16
+        )
+        self.buf_k: List[mx.array] = []
+        self.buf_v: List[mx.array] = []
+        for layer in self.layers:
+            attn = layer.self_attn
+            shape = (1, attn.n_kv_heads, canvas_length, attn.head_dim)
+            self.buf_k.append(mx.zeros(shape, dtype=dtype))
+            self.buf_v.append(mx.zeros(shape, dtype=dtype))
+        self._scatter_idx = []
+
+    def set_forward_positions(self, canvas_rel_positions: mx.array):
+        """Register the canvas slots that the next call(s) will write.
+
+        Args:
+            canvas_rel_positions: 1-D array of canvas-relative slot indices,
+                in the same order as the tokens later passed to ``__call__``.
+        """
+        count = canvas_rel_positions.size
+        index_column = canvas_rel_positions[None, None, :, None]
+        # One scatter index per layer, shaped like that layer's new keys.
+        self._scatter_idx = [
+            mx.broadcast_to(
+                index_column,
+                (1, layer.self_attn.n_kv_heads, count, layer.self_attn.head_dim),
+            )
+            for layer in self.layers
+        ]
+
+    def __call__(
+        self,
+        tokens_f: mx.array,
+        positions_f: mx.array,
+        sc_f: Optional[mx.array],
+    ):
+        """Run every decoder layer over the forwarded canvas positions.
+
+        Args:
+            tokens_f: Token ids of the forwarded positions.
+            positions_f: Absolute positions used to index the RoPE tables,
+                one per forwarded token.
+            sc_f: Soft self-conditioning embeddings for the forwarded
+                positions, or ``None`` to use zeros.
+
+        Returns:
+            The decoder's final-norm hidden states for the forwarded positions.
+
+        Raises:
+            RuntimeError: If ``set_forward_positions`` has not been called.
+            ValueError: If the number of forwarded tokens differs from the
+                number of registered canvas positions.
+        """
+        if len(self._scatter_idx) != len(self.layers):
+            raise RuntimeError(
+                "TurboCanvasRunner.set_forward_positions() must be called "
+                "before running the decoder."
+            )
+
+        decoder = self.decoder
+        inputs_embeds = decoder.embed_tokens(tokens_f) * decoder.embed_scale
+        soft = (
+            sc_f.astype(inputs_embeds.dtype)
+            if sc_f is not None
+            else mx.zeros_like(inputs_embeds)
+        )
+        h = decoder.self_conditioning(inputs_embeds, soft)
+
+        batch_size, forward_count, _ = h.shape
+        if self._scatter_idx and self._scatter_idx[0].shape[2] != forward_count:
+            raise ValueError(
+                f"Got {forward_count} forwarded tokens but "
+                f"{self._scatter_idx[0].shape[2]} registered canvas positions."
+            )
+
+        # Gather the RoPE rows for the forwarded positions once per call.
+        cos_s = self.cos_s[positions_f]
+        sin_s = self.sin_s[positions_f]
+        cos_f = self.cos_f[positions_f]
+        sin_f = self.sin_f[positions_f]
+
+        for layer_index, layer in enumerate(self.layers):
+            attn = layer.self_attn
+            residual = h
+            h_norm = layer.input_layernorm(h)
+            q, k, v = self._project_qkv(attn, h_norm, batch_size, forward_count)
+
+            if attn.is_sliding:
+                q = _apply_rope_gathered(q, cos_s, sin_s)
+                k = _apply_rope_gathered(k, cos_s, sin_s)
+            elif self.full_rotated:
+                q = self._proportional_rope(q, attn.head_dim, cos_f, sin_f)
+                k = self._proportional_rope(k, attn.head_dim, cos_f, sin_f)
+
+            keys, values = self._update_canvas_kv(layer_index, k, v)
+
+            # No mask: each forwarded query sees the whole prefix and canvas.
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                keys,
+                values,
+                scale=1.0,
+                mask=None,
+            )
+            out = out.transpose(0, 2, 1, 3).reshape(batch_size, forward_count, -1)
+            h = residual + layer.post_attention_layernorm(attn.o_proj(out))
+            h = self._feed_forward(layer, h)
+
+        return decoder.norm(h)
+
+    @staticmethod
+    def _project_qkv(attn, h_norm, batch_size: int, forward_count: int):
+        """Project, normalise, and head-split queries, keys, and values."""
+        q = attn.q_proj(h_norm).reshape(
+            batch_size,
+            forward_count,
+            attn.n_heads,
+            attn.head_dim,
+        )
+        q = attn.q_norm(q).transpose(0, 2, 1, 3)
+        k = attn.k_proj(h_norm).reshape(
+            batch_size,
+            forward_count,
+            attn.n_kv_heads,
+            attn.head_dim,
+        )
+        # Layers without a value projection reuse the un-normalised keys.
+        v_raw = (
+            attn.v_proj(h_norm).reshape(
+                batch_size,
+                forward_count,
+                attn.n_kv_heads,
+                attn.head_dim,
+            )
+            if attn.v_proj is not None
+            else k
+        )
+        k = attn.k_norm(k).transpose(0, 2, 1, 3)
+        v = attn.v_norm(v_raw).transpose(0, 2, 1, 3)
+        return q, k, v
+
+    def _proportional_rope(self, x, head_dim: int, cos, sin):
+        """Rotate the leading ``full_rotated // 2`` dims of each head half."""
+        half = head_dim // 2
+        rotary_half = self.full_rotated // 2
+        left = x[..., :half]
+        right = x[..., half:]
+        rotary = mx.concatenate(
+            [left[..., :rotary_half], right[..., :rotary_half]],
+            axis=-1,
+        )
+        rotary = _apply_rope_gathered(rotary, cos, sin)
+        left = mx.concatenate(
+            [rotary[..., :rotary_half], left[..., rotary_half:]],
+            axis=-1,
+        )
+        right = mx.concatenate(
+            [rotary[..., rotary_half:], right[..., rotary_half:]],
+            axis=-1,
+        )
+        return mx.concatenate([left, right], axis=-1)
+
+    def _update_canvas_kv(self, layer_index: int, k, v):
+        """Scatter new keys/values into the buffers and prepend the prefix."""
+        scatter_idx = self._scatter_idx[layer_index]
+        self.buf_k[layer_index] = mx.put_along_axis(
+            self.buf_k[layer_index],
+            scatter_idx,
+            k.astype(self.buf_k[layer_index].dtype),
+            axis=2,
+        )
+        self.buf_v[layer_index] = mx.put_along_axis(
+            self.buf_v[layer_index],
+            scatter_idx,
+            v.astype(self.buf_v[layer_index].dtype),
+            axis=2,
+        )
+
+        keys = self.buf_k[layer_index]
+        values = self.buf_v[layer_index]
+        if self.prefix_k[layer_index] is not None:
+            keys = mx.concatenate([self.prefix_k[layer_index], keys], axis=2)
+            values = mx.concatenate([self.prefix_v[layer_index], values], axis=2)
+        return keys, values
+
+    @staticmethod
+    def _feed_forward(layer, h):
+        """Sum the ``mlp`` and ``router``/``experts`` branches and scale."""
+        residual = h
+        h1 = layer.pre_feedforward_layernorm(h)
+        h1 = layer.mlp(h1)
+        h1 = layer.post_feedforward_layernorm_1(h1)
+
+        # The router and experts operate on a flattened (tokens, hidden) view.
+        flat = residual.reshape(-1, residual.shape[-1])
+        top_k_indices, top_k_weights = layer.router(flat)
+        h2 = layer.pre_feedforward_layernorm_2(flat)
+        h2 = layer.experts(h2, top_k_indices, top_k_weights)
+        h2 = h2.reshape(residual.shape)
+        h2 = layer.post_feedforward_layernorm_2(h2)
+
+        h = layer.post_feedforward_layernorm(h1 + h2)
+        h = residual + h
+        return h * layer.layer_scalar
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment