Created
June 17, 2026 19:44
-
-
Save lucasnewman/ee6d792db0b5f23f8ad343c7a564944c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/mlx_vlm/generate/diffusion.py b/mlx_vlm/generate/diffusion.py | |
| index ae56794..31e7ba4 100644 | |
| --- a/mlx_vlm/generate/diffusion.py | |
| +++ b/mlx_vlm/generate/diffusion.py | |
| @@ -4,6 +4,7 @@ import logging | |
| import os | |
| import shutil | |
| import time | |
| +from dataclasses import dataclass | |
| from typing import Any, Dict, Generator, List, Optional, Tuple | |
| import mlx.core as mx | |
| @@ -31,6 +32,47 @@ DEFAULT_DIFFUSION_MAX_DENOISING_STEPS = 48 | |
| DEFAULT_DIFFUSION_UNMASKING_WIDTH = 0 | |
@dataclass
class DiffusionCanvasDraft:
    """Progress snapshot emitted by a canvas denoiser while it runs.

    The stream turns each of these into a draft ``GenerationResult``.
    """

    # Rendered, partially revealed canvas text for display.
    text: str
    # 1-based denoising step the snapshot was taken after (0 = before any pass).
    step: int
    # Step budget reported alongside ``step`` for progress display.
    total_steps: int
| + | |
| + | |
@dataclass
class DiffusionCanvasResult:
    """Final output of a canvas denoiser for one canvas.

    The stream commits ``canvas`` and adds the three counters to its
    running totals.
    """

    # Token ids to commit; the stream commits ``canvas.shape[-1]`` tokens
    # from row 0.
    canvas: mx.array
    # Number of canvas positions this result accounts for.
    canvas_tokens: int
    # Number of decoder passes the denoiser ran for this canvas.
    denoising_steps: int
    # Positions processed across all passes (throughput accounting).
    work_tokens: int
| + | |
| + | |
@dataclass
class DiffusionCanvasDenoiseContext:
    """Everything a per-canvas denoiser needs to denoise one canvas.

    Built fresh by ``stream_diffusion_generate`` for every canvas. A
    denoiser may set ``compile_graph`` to ``False`` after falling back to the
    eager decoder path; the stream copies that value back afterwards.
    """

    # Model and tokenizer used for decoding and draft rendering.
    model: nn.Module
    tokenizer: PreTrainedTokenizer
    # dtype used for canvas token ids.
    input_dtype: Any
    # Canvas geometry, also forwarded to ``_diffusion_initialize_canvas``.
    batch_size: int
    canvas_length: int
    vocab_size: int
    # Prefix KV cache and the decoder attention mask for this canvas.
    kv_cache: Any
    decoder_attention_mask: Optional[mx.array]
    # Decoder masks; the turbo denoiser indexes this by ``layer.layer_type``.
    mask_mapping: Any
    # Step budget and temperature settings.
    max_denoising_steps: int
    temperature: float
    temperature_config: Dict[str, Any]
    # Acceptance policy ("entropy-bound" or confidence thresholding) and
    # its parameters.
    sampler: str
    threshold: float
    entropy_bound: float
    stopping_config: Optional[Dict[str, Any]]
    # Weights used to build self-conditioning soft embeddings.
    soft_embedding_weight: mx.array
    # Whether to use the compiled decoder graph (may be cleared on fallback).
    compile_graph: bool
    # Draft display settings.
    show_unmasking: bool
    unmasking_interval: int
    unmasking_width: int
    skip_special_token_ids: Any
| + | |
| + | |
| def _diffusion_display_limit(requested_width: Optional[int] = None) -> Optional[int]: | |
| terminal_width = shutil.get_terminal_size((120, 20)).columns | |
| requested_width = requested_width or DEFAULT_DIFFUSION_UNMASKING_WIDTH | |
| @@ -527,6 +569,219 @@ def _decode_diffusion_masked_draft( | |
| return escape_carriage_returns(text) | |
# NOTE(review): the return annotation is quoted so the PEP 604 ``X | Y`` union
# is not evaluated at import time; unquoted it requires Python 3.10+ unless
# this module enables ``from __future__ import annotations`` (not visible
# here -- confirm against the supported Python versions).
def _standard_diffusion_canvas_denoiser(
    ctx: DiffusionCanvasDenoiseContext,
) -> "Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]":
    """Default per-canvas denoiser used by ``stream_diffusion_generate``.

    Runs up to ``ctx.max_denoising_steps`` decoder passes over one canvas,
    counting ``cur_step`` down from ``ctx.max_denoising_steps`` to 1. When
    ``ctx.show_unmasking`` is set it yields ``DiffusionCanvasDraft`` snapshots:
    one before the first pass (step 0), then after the first pass, after
    every ``ctx.unmasking_interval``-th pass, and after the last scheduled
    pass. After the loop it yields exactly one ``DiffusionCanvasResult``
    whose canvas is the argmax of the final logits.

    Acceptance policy depends on ``ctx.sampler``:

    * ``"entropy-bound"``: every pass re-selects the accepted positions from
      scratch via ``_diffusion_entropy_transfer_mask``.
    * any other value (confidence thresholding): accepted positions
      accumulate in ``draft_reveal_mask`` and the last scheduled pass forces
      acceptance of every remaining position. With
      ``"confidence-threshold"`` the loop also stops as soon as every
      position is revealed.

    The loop can also stop early when ``_diffusion_stable_and_confident``
    returns true. Positions that are not accepted are reset from a fresh
    ``_diffusion_initialize_canvas`` call before the next pass (presumably the
    fully masked canvas -- confirm).

    Side effect: if the compiled decoder raises, ``ctx.compile_graph`` is set
    to ``False``; the caller reads it back so later canvases stay eager.
    """
    current_canvas = _diffusion_initialize_canvas(
        ctx.batch_size,
        ctx.canvas_length,
        ctx.vocab_size,
        ctx.input_dtype,
    )
    # Nothing is revealed yet; all draft/accepted views start as the
    # initial canvas.
    draft_reveal_mask = mx.zeros(current_canvas.shape, dtype=mx.bool_)
    draft_canvas = current_canvas
    accepted_canvas = current_canvas
    argmax_canvas = current_canvas
    self_conditioning_embeddings = None
    decoder_logits_without_sc, decoder_logits_with_sc = (
        _make_diffusion_decoder_logits_fns(
            ctx.model,
            ctx.kv_cache,
            ctx.mask_mapping,
            compile_graph=ctx.compile_graph,
        )
    )
    # Passed to _diffusion_stable_and_confident on every pass (presumably it
    # records per-step state there -- confirm).
    diffusion_history: List[mx.array] = []
    denoising_steps = 0
    # Local copy of the compile flag; flipped to False on eager fallback.
    compile_graph = ctx.compile_graph

    if ctx.show_unmasking:
        draft_text = _decode_diffusion_masked_draft(
            ctx.tokenizer,
            [int(token_id) for token_id in draft_canvas[0].tolist()],
            [False] * ctx.canvas_length,
            ctx.skip_special_token_ids,
            max_chars=ctx.unmasking_width,
        )
        yield DiffusionCanvasDraft(
            text=draft_text,
            step=0,
            total_steps=ctx.max_denoising_steps,
        )

    for cur_step in reversed(range(1, ctx.max_denoising_steps + 1)):
        denoising_steps += 1
        try:
            if self_conditioning_embeddings is None:
                processed_logits = decoder_logits_without_sc(current_canvas)
            else:
                processed_logits = decoder_logits_with_sc(
                    current_canvas,
                    self_conditioning_embeddings,
                )
        except Exception as exc:
            # Only a compiled graph is allowed to fail softly; eager errors
            # propagate.
            if not compile_graph:
                raise
            logger.warning(
                "Diffusion decoder compilation failed; falling back "
                "to the eager path: %s",
                exc,
            )
            compile_graph = False
            # Published on the context so the caller keeps later canvases eager.
            ctx.compile_graph = False
            decoder_logits_without_sc, decoder_logits_with_sc = (
                _make_diffusion_decoder_logits_fns(
                    ctx.model,
                    ctx.kv_cache,
                    ctx.mask_mapping,
                    compile_graph=False,
                )
            )
            # Retry this same pass with the eager decoder functions.
            if self_conditioning_embeddings is None:
                processed_logits = decoder_logits_without_sc(current_canvas)
            else:
                processed_logits = decoder_logits_with_sc(
                    current_canvas,
                    self_conditioning_embeddings,
                )
        schedule_temperature = _diffusion_linear_temperature(
            cur_step,
            ctx.max_denoising_steps,
            ctx.temperature_config,
        )
        if schedule_temperature is not None:
            processed_logits = processed_logits / schedule_temperature

        # Greedy tokens from the (possibly temperature-scaled) logits; the
        # final value is what the result reports.
        argmax_canvas = mx.argmax(processed_logits, axis=-1).astype(ctx.input_dtype)
        # On the last pass only the argmax is needed unless a draft must
        # still be rendered.
        if cur_step == 1 and not ctx.show_unmasking:
            break

        # Proposal tokens: greedy when temperature <= 0, otherwise sampled.
        denoiser_canvas = (
            argmax_canvas
            if ctx.temperature <= 0
            else _diffusion_sample_canvas(
                processed_logits,
                ctx.input_dtype,
                ctx.temperature,
            )
        )

        if ctx.sampler == "entropy-bound":
            if cur_step > 1:
                # One pass yields both the entropy and the next step's
                # self-conditioning embeddings.
                token_entropy, next_self_conditioning_embeddings = (
                    _diffusion_entropy_and_soft_embeddings(
                        processed_logits,
                        ctx.soft_embedding_weight,
                        ctx.model.model.decoder.embed_scale,
                    )
                )
            else:
                # Last pass: no further self-conditioning is needed.
                token_entropy = _diffusion_token_entropy(processed_logits)
                next_self_conditioning_embeddings = None
            acceptance_mask = _diffusion_entropy_transfer_mask(
                token_entropy,
                ctx.entropy_bound,
            )
            accepted_canvas = mx.where(
                acceptance_mask,
                denoiser_canvas,
                current_canvas,
            )
            # Keep accepted tokens; everything else goes back to a fresh
            # initial canvas for the next pass.
            current_canvas = mx.where(
                acceptance_mask,
                accepted_canvas,
                _diffusion_initialize_canvas(
                    ctx.batch_size,
                    ctx.canvas_length,
                    ctx.vocab_size,
                    ctx.input_dtype,
                ),
            )
            # The reveal set is recomputed each pass, not accumulated.
            draft_reveal_mask = acceptance_mask
            draft_canvas = argmax_canvas
        else:
            next_self_conditioning_embeddings = None
            # Confidence thresholding only considers positions not yet revealed.
            unrevealed_mask = ~draft_reveal_mask
            confidence = _diffusion_token_probability(
                processed_logits,
                denoiser_canvas,
            )
            acceptance_mask = _diffusion_confidence_transfer_mask(
                confidence,
                unrevealed_mask,
                ctx.threshold,
                force_all=cur_step == 1,
            )
            accepted_canvas = mx.where(
                acceptance_mask,
                denoiser_canvas,
                draft_canvas,
            )
            current_canvas = mx.where(
                draft_reveal_mask | acceptance_mask,
                accepted_canvas,
                _diffusion_initialize_canvas(
                    ctx.batch_size,
                    ctx.canvas_length,
                    ctx.vocab_size,
                    ctx.input_dtype,
                ),
            )
            # Revealed positions accumulate across passes.
            draft_reveal_mask = draft_reveal_mask | acceptance_mask
            draft_canvas = mx.where(acceptance_mask, accepted_canvas, draft_canvas)

        # 1-based pass counter for display (cur_step counts down).
        displayed_step = ctx.max_denoising_steps - cur_step + 1
        should_show_unmasking = ctx.show_unmasking and (
            displayed_step == 1
            or cur_step == 1
            or displayed_step % ctx.unmasking_interval == 0
        )
        if should_show_unmasking:
            mx.eval(draft_canvas, draft_reveal_mask)
            draft_text = _decode_diffusion_masked_draft(
                ctx.tokenizer,
                [int(token_id) for token_id in draft_canvas[0].tolist()],
                [bool(v) for v in draft_reveal_mask[0].tolist()],
                ctx.skip_special_token_ids,
                max_chars=ctx.unmasking_width,
            )
            yield DiffusionCanvasDraft(
                text=draft_text,
                step=displayed_step,
                total_steps=ctx.max_denoising_steps,
            )

        # Confidence thresholding: stop once every position is revealed.
        if ctx.sampler == "confidence-threshold" and bool(
            mx.all(draft_reveal_mask).item()
        ):
            accepted_canvas = draft_canvas
            break

        if _diffusion_stable_and_confident(
            argmax_canvas,
            processed_logits,
            diffusion_history,
            ctx.stopping_config,
        ):
            break

        # Feed soft embeddings of this pass's logits into the next pass.
        if cur_step > 1:
            if next_self_conditioning_embeddings is None:
                next_self_conditioning_embeddings = _diffusion_soft_embeddings(
                    processed_logits,
                    ctx.soft_embedding_weight,
                    ctx.model.model.decoder.embed_scale,
                )
            self_conditioning_embeddings = next_self_conditioning_embeddings

    # NOTE(review): the committed canvas is the final argmax, not
    # ``draft_canvas``/``accepted_canvas``. With temperature > 0 the sampled
    # acceptances are therefore not what gets committed -- confirm intended.
    yield DiffusionCanvasResult(
        canvas=argmax_canvas,
        canvas_tokens=ctx.canvas_length,
        denoising_steps=denoising_steps,
        work_tokens=ctx.canvas_length * denoising_steps,
    )
| + | |
| + | |
| def stream_diffusion_generate( | |
| model: nn.Module, | |
| processor: PreTrainedTokenizer, | |
| @@ -551,6 +806,8 @@ def stream_diffusion_generate( | |
| diffusion_unmasking_width: int = DEFAULT_DIFFUSION_UNMASKING_WIDTH, | |
| mm_token_type_ids: Optional[mx.array] = None, | |
| prefill_step_size: Optional[int] = None, | |
| + diffusion_canvas_denoiser: Optional[Any] = None, | |
| + diffusion_repeat_guard: Optional[int] = None, | |
| ) -> Generator[GenerationResult, None, None]: | |
| if input_ids.shape[0] != 1: | |
| raise ValueError( | |
| @@ -604,6 +861,9 @@ def stream_diffusion_generate( | |
| prefill_step_size = int(prefill_step_size) | |
| if prefill_step_size <= 0: | |
| raise ValueError("prefill_step_size must be a positive integer.") | |
| + repeat_guard = int(diffusion_repeat_guard or 0) | |
| + if repeat_guard < 0: | |
| + raise ValueError("diffusion_repeat_guard must be non-negative.") | |
| sampler_config = _diffusion_config_dict( | |
| generation_config.get("sampler_config", None) | |
| @@ -690,6 +950,9 @@ def stream_diffusion_generate( | |
| current_canvas = None | |
| stopped = False | |
| stop_reason = "length" | |
| + prev_emit_id = None | |
| + repeat_run = 0 | |
| + canvas_denoiser = diffusion_canvas_denoiser or _standard_diffusion_canvas_denoiser | |
| def make_result( | |
| text: str, | |
| @@ -778,227 +1041,62 @@ def stream_diffusion_generate( | |
| if decoder_attention_mask is not None | |
| else None | |
| ) | |
| - current_canvas = _diffusion_initialize_canvas( | |
| - batch_size, | |
| - canvas_length, | |
| - vocab_size, | |
| - input_ids.dtype, | |
| - ) | |
| - draft_reveal_mask = mx.zeros(current_canvas.shape, dtype=mx.bool_) | |
| - draft_canvas = current_canvas | |
| - accepted_canvas = current_canvas | |
| - argmax_canvas = current_canvas | |
| - self_conditioning_embeddings = None | |
| mask_mapping = model.model.decoder._make_decoder_masks( | |
| - current_canvas[..., None], | |
| + mx.zeros((batch_size, canvas_length, 1), dtype=input_ids.dtype), | |
| kv_cache, | |
| current_decoder_attention_mask, | |
| ) | |
| - decoder_logits_without_sc, decoder_logits_with_sc = ( | |
| - _make_diffusion_decoder_logits_fns( | |
| - model, | |
| - kv_cache, | |
| - mask_mapping, | |
| - compile_graph=diffusion_compile, | |
| - ) | |
| - ) | |
| - diffusion_history: List[mx.array] = [] | |
| - denoising_steps_this_canvas = 0 | |
| - | |
| - if diffusion_show_unmasking: | |
| - draft_text = _decode_diffusion_masked_draft( | |
| - tokenizer, | |
| - [int(token_id) for token_id in draft_canvas[0].tolist()], | |
| - [False] * canvas_length, | |
| - skip_special_token_ids, | |
| - max_chars=diffusion_unmasking_width, | |
| - ) | |
| - yield make_result( | |
| - "", | |
| - is_draft=True, | |
| - draft_text=draft_text, | |
| - diffusion_step=0, | |
| - diffusion_total_steps=max_denoising_steps, | |
| - diffusion_canvas_index=canvas_index, | |
| - ) | |
| - | |
| - for cur_step in reversed(range(1, max_denoising_steps + 1)): | |
| - denoising_steps_this_canvas += 1 | |
| - try: | |
| - if self_conditioning_embeddings is None: | |
| - processed_logits = decoder_logits_without_sc(current_canvas) | |
| - else: | |
| - processed_logits = decoder_logits_with_sc( | |
| - current_canvas, | |
| - self_conditioning_embeddings, | |
| - ) | |
| - except Exception as exc: | |
| - if not diffusion_compile: | |
| - raise | |
| - logger.warning( | |
| - "Diffusion decoder compilation failed; falling back " | |
| - "to the eager path: %s", | |
| - exc, | |
| - ) | |
| - diffusion_compile = False | |
| - decoder_logits_without_sc, decoder_logits_with_sc = ( | |
| - _make_diffusion_decoder_logits_fns( | |
| - model, | |
| - kv_cache, | |
| - mask_mapping, | |
| - compile_graph=False, | |
| - ) | |
| - ) | |
| - if self_conditioning_embeddings is None: | |
| - processed_logits = decoder_logits_without_sc(current_canvas) | |
| - else: | |
| - processed_logits = decoder_logits_with_sc( | |
| - current_canvas, | |
| - self_conditioning_embeddings, | |
| - ) | |
| - schedule_temperature = _diffusion_linear_temperature( | |
| - cur_step, | |
| - max_denoising_steps, | |
| - temperature_config, | |
| - ) | |
| - if schedule_temperature is not None: | |
| - processed_logits = processed_logits / schedule_temperature | |
| - | |
| - argmax_canvas = mx.argmax(processed_logits, axis=-1).astype( | |
| - input_ids.dtype | |
| - ) | |
| - if cur_step == 1 and not diffusion_show_unmasking: | |
| - break | |
| - | |
| - denoiser_canvas = ( | |
| - argmax_canvas | |
| - if temperature <= 0 | |
| - else _diffusion_sample_canvas( | |
| - processed_logits, | |
| - input_ids.dtype, | |
| - temperature, | |
| - ) | |
| - ) | |
| - if diffusion_sampler == "entropy-bound": | |
| - if cur_step > 1: | |
| - token_entropy, next_self_conditioning_embeddings = ( | |
| - _diffusion_entropy_and_soft_embeddings( | |
| - processed_logits, | |
| - soft_embedding_weight, | |
| - model.model.decoder.embed_scale, | |
| - ) | |
| - ) | |
| - else: | |
| - token_entropy = _diffusion_token_entropy(processed_logits) | |
| - next_self_conditioning_embeddings = None | |
| - acceptance_mask = _diffusion_entropy_transfer_mask( | |
| - token_entropy, | |
| - entropy_bound, | |
| - ) | |
| - accepted_canvas = mx.where( | |
| - acceptance_mask, | |
| - denoiser_canvas, | |
| - current_canvas, | |
| - ) | |
| - current_canvas = mx.where( | |
| - acceptance_mask, | |
| - accepted_canvas, | |
| - _diffusion_initialize_canvas( | |
| - batch_size, | |
| - canvas_length, | |
| - vocab_size, | |
| - input_ids.dtype, | |
| - ), | |
| - ) | |
| - draft_reveal_mask = acceptance_mask | |
| - draft_canvas = argmax_canvas | |
| - else: | |
| - next_self_conditioning_embeddings = None | |
| - unrevealed_mask = ~draft_reveal_mask | |
| - confidence = _diffusion_token_probability( | |
| - processed_logits, | |
| - denoiser_canvas, | |
| - ) | |
| - acceptance_mask = _diffusion_confidence_transfer_mask( | |
| - confidence, | |
| - unrevealed_mask, | |
| - diffusion_threshold, | |
| - force_all=cur_step == 1, | |
| - ) | |
| - accepted_canvas = mx.where( | |
| - acceptance_mask, | |
| - denoiser_canvas, | |
| - draft_canvas, | |
| - ) | |
| - current_canvas = mx.where( | |
| - draft_reveal_mask | acceptance_mask, | |
| - accepted_canvas, | |
| - _diffusion_initialize_canvas( | |
| - batch_size, | |
| - canvas_length, | |
| - vocab_size, | |
| - input_ids.dtype, | |
| - ), | |
| - ) | |
| - draft_reveal_mask = draft_reveal_mask | acceptance_mask | |
| - draft_canvas = mx.where( | |
| - acceptance_mask, accepted_canvas, draft_canvas | |
| - ) | |
| + ctx = DiffusionCanvasDenoiseContext( | |
| + model=model, | |
| + tokenizer=tokenizer, | |
| + input_dtype=input_ids.dtype, | |
| + batch_size=batch_size, | |
| + canvas_length=canvas_length, | |
| + vocab_size=vocab_size, | |
| + kv_cache=kv_cache, | |
| + decoder_attention_mask=current_decoder_attention_mask, | |
| + mask_mapping=mask_mapping, | |
| + max_denoising_steps=max_denoising_steps, | |
| + temperature=temperature, | |
| + temperature_config=temperature_config, | |
| + sampler=diffusion_sampler, | |
| + threshold=diffusion_threshold, | |
| + entropy_bound=entropy_bound, | |
| + stopping_config=diffusion_stopping_config, | |
| + soft_embedding_weight=soft_embedding_weight, | |
| + compile_graph=diffusion_compile, | |
| + show_unmasking=diffusion_show_unmasking, | |
| + unmasking_interval=diffusion_unmasking_interval, | |
| + unmasking_width=diffusion_unmasking_width, | |
| + skip_special_token_ids=skip_special_token_ids, | |
| + ) | |
| - displayed_step = max_denoising_steps - cur_step + 1 | |
| - should_show_unmasking = diffusion_show_unmasking and ( | |
| - displayed_step == 1 | |
| - or cur_step == 1 | |
| - or displayed_step % diffusion_unmasking_interval == 0 | |
| - ) | |
| - if should_show_unmasking: | |
| - mx.eval(draft_canvas, draft_reveal_mask) | |
| - draft_text = _decode_diffusion_masked_draft( | |
| - tokenizer, | |
| - [int(token_id) for token_id in draft_canvas[0].tolist()], | |
| - [bool(v) for v in draft_reveal_mask[0].tolist()], | |
| - skip_special_token_ids, | |
| - max_chars=diffusion_unmasking_width, | |
| - ) | |
| + canvas_result = None | |
| + for event in canvas_denoiser(ctx): | |
| + if isinstance(event, DiffusionCanvasDraft): | |
| yield make_result( | |
| "", | |
| is_draft=True, | |
| - draft_text=draft_text, | |
| - diffusion_step=displayed_step, | |
| - diffusion_total_steps=max_denoising_steps, | |
| + draft_text=event.text, | |
| + diffusion_step=event.step, | |
| + diffusion_total_steps=event.total_steps, | |
| diffusion_canvas_index=canvas_index, | |
| ) | |
| + continue | |
| + canvas_result = event | |
| - if diffusion_sampler == "confidence-threshold" and bool( | |
| - mx.all(draft_reveal_mask).item() | |
| - ): | |
| - accepted_canvas = draft_canvas | |
| - break | |
| - | |
| - if _diffusion_stable_and_confident( | |
| - argmax_canvas, | |
| - processed_logits, | |
| - diffusion_history, | |
| - diffusion_stopping_config, | |
| - ): | |
| - break | |
| + diffusion_compile = ctx.compile_graph | |
| + if canvas_result is None: | |
| + raise RuntimeError("Diffusion canvas denoiser did not return a canvas.") | |
| - if cur_step > 1: | |
| - if next_self_conditioning_embeddings is None: | |
| - next_self_conditioning_embeddings = _diffusion_soft_embeddings( | |
| - processed_logits, | |
| - soft_embedding_weight, | |
| - model.model.decoder.embed_scale, | |
| - ) | |
| - self_conditioning_embeddings = next_self_conditioning_embeddings | |
| - | |
| - current_canvas = argmax_canvas | |
| - diffusion_canvas_tokens += canvas_length | |
| - diffusion_denoising_steps += denoising_steps_this_canvas | |
| - diffusion_work_tokens += canvas_length * denoising_steps_this_canvas | |
| + current_canvas = canvas_result.canvas | |
| + diffusion_canvas_tokens += canvas_result.canvas_tokens | |
| + diffusion_denoising_steps += canvas_result.denoising_steps | |
| + diffusion_work_tokens += canvas_result.work_tokens | |
| mx.eval(current_canvas) | |
| + committed_length = current_canvas.shape[-1] | |
| for token_id in current_canvas[0].tolist(): | |
| last_token = int(token_id) | |
| generated_tokens += 1 | |
| @@ -1008,6 +1106,14 @@ def stream_diffusion_generate( | |
| stop_reason = "stop" | |
| break | |
| + if repeat_guard: | |
| + repeat_run = repeat_run + 1 if last_token == prev_emit_id else 1 | |
| + prev_emit_id = last_token | |
| + if repeat_run >= repeat_guard: | |
| + stopped = True | |
| + stop_reason = "repetition" | |
| + break | |
| + | |
| detokenizer.add_token( | |
| last_token, skip_special_token_ids=skip_special_token_ids | |
| ) | |
| @@ -1034,14 +1140,14 @@ def stream_diffusion_generate( | |
| if use_static_cache: | |
| decoder_attention_mask[ | |
| - :, cached_sequence_length : cached_sequence_length + canvas_length | |
| + :, cached_sequence_length : cached_sequence_length + committed_length | |
| ] = True | |
| - cached_sequence_length += canvas_length | |
| + cached_sequence_length += committed_length | |
| elif decoder_attention_mask is not None: | |
| decoder_attention_mask = mx.concatenate( | |
| [ | |
| decoder_attention_mask, | |
| - mx.ones((batch_size, canvas_length), dtype=mx.bool_), | |
| + mx.ones((batch_size, committed_length), dtype=mx.bool_), | |
| ], | |
| axis=-1, | |
| ) | |
| @@ -1077,6 +1183,15 @@ def stream_diffusion_generate_from_kwargs( | |
| diffusion_unmasking_width = kwargs.pop( | |
| "diffusion_unmasking_width", DEFAULT_DIFFUSION_UNMASKING_WIDTH | |
| ) | |
| + diffusion_turbo = kwargs.pop("diffusion_turbo", False) | |
| + diffusion_canvas_denoiser = None | |
| + diffusion_repeat_guard = None | |
| + if diffusion_turbo: | |
| + from .diffusion_turbo import make_diffusion_turbo_denoiser_from_kwargs | |
| + | |
| + diffusion_canvas_denoiser, diffusion_repeat_guard = ( | |
| + make_diffusion_turbo_denoiser_from_kwargs(kwargs) | |
| + ) | |
| mm_token_type_ids = kwargs.pop("mm_token_type_ids", None) | |
| prefill_step_size = kwargs.pop("prefill_step_size", None) | |
| with wired_limit(model, [generation_stream]): | |
| @@ -1103,5 +1218,7 @@ def stream_diffusion_generate_from_kwargs( | |
| diffusion_unmasking_width=diffusion_unmasking_width, | |
| mm_token_type_ids=mm_token_type_ids, | |
| prefill_step_size=prefill_step_size, | |
| + diffusion_canvas_denoiser=diffusion_canvas_denoiser, | |
| + diffusion_repeat_guard=diffusion_repeat_guard, | |
| ) | |
| mx.clear_cache() | |
| diff --git a/mlx_vlm/generate/diffusion_turbo.py b/mlx_vlm/generate/diffusion_turbo.py | |
| new file mode 100644 | |
| index 0000000..a31ab1a | |
| --- /dev/null | |
| +++ b/mlx_vlm/generate/diffusion_turbo.py | |
| @@ -0,0 +1,744 @@ | |
| +"""Turbo per-canvas denoisers for block-diffusion generation. | |
| + | |
| +The shared stream lifecycle stays in ``diffusion.py``: prompt prefill, decoder | |
| +mask construction, cache updates, token emission, draft emission, and result | |
| +accounting. This module only supplies optional denoising policies that trade | |
| +fidelity for throughput. | |
| +""" | |
| + | |
| +from __future__ import annotations | |
| + | |
| +from dataclasses import dataclass | |
| +from typing import Any, Generator, Optional | |
| + | |
| +import mlx.core as mx | |
| + | |
| +from .diffusion import ( | |
| + DiffusionCanvasDenoiseContext, | |
| + DiffusionCanvasDraft, | |
| + DiffusionCanvasResult, | |
| + _decode_diffusion_masked_draft, | |
| + _diffusion_initialize_canvas, | |
| +) | |
| + | |
| +DEFAULT_TURBO_TOPK = 64 | |
| +_SHAPE_BUCKETS = (16, 32, 48, 64, 96, 128, 160, 192, 224, 256) | |
| + | |
| + | |
@dataclass
class TurboDiffusionConfig:
    """Settings for ``TurboDiffusionDenoiser``.

    Normally built from ``turbo_*`` keyword arguments by
    ``make_diffusion_turbo_denoiser_from_kwargs``.
    """

    # Candidate tokens kept per position by the top-k logits path.
    topk: int = DEFAULT_TURBO_TOPK
    # After the first pass, only still-active positions are projected to logits.
    monotone: bool = True
    # Presumably stops a canvas early once EOS is committed -- confirm.
    eos_early_stop: bool = True
    # Minimum acceptance quota; its exact use is outside this view -- confirm.
    quota: float = 0.0
    # Turbo step budget; also reported as ``total_steps`` in drafts when set.
    steps: Optional[int] = None
    # Acceptance rule: "entropy-bound" or "confidence".
    accept: str = "entropy-bound"
    # Confidence threshold: a scalar, or a (high, low, switch_step) triple
    # meaning ``high`` before ``switch_step`` and ``low`` from it on.
    threshold: Any = 0.9
    # When > 0, confidence acceptance is limited to ``window`` positions
    # past the first active position.
    window: int = 0
    # Presumably enables a repair pass -- usage not visible here, confirm.
    repair: bool = False
    # Use the compact denoising strategy instead of the full-canvas one.
    compact: bool = False
    # Run length of identical emitted tokens that stops generation (0 = off).
    repeat_guard: int = 16
    # Overrides the context's entropy bound when not None.
    entropy_bound: Optional[float] = None
| + | |
| + | |
def make_diffusion_turbo_denoiser_from_kwargs(kwargs: dict):
    """Build a turbo canvas denoiser from generation keyword arguments.

    Every ``turbo_*`` key is popped from ``kwargs`` so it is not forwarded to
    the generic generation path. ``entropy_bound`` is only read, not popped.
    It is a shared diffusion setting, so ``kwargs`` keeps it exactly as it
    would without turbo, and the turbo denoiser sees the same value.

    Returns:
        ``(denoiser, repeat_guard)``: a ``TurboDiffusionDenoiser`` and the
        repeat-guard run length for ``stream_diffusion_generate``
        (``0`` disables it).

    Raises:
        ValueError: on a non-positive ``turbo_topk`` or ``turbo_steps``, an
            unsupported ``turbo_accept``, a negative ``turbo_repeat_guard``,
            or a malformed ``turbo_threshold`` when
            ``turbo_accept="confidence"``.
    """
    no_monotone = bool(kwargs.pop("turbo_no_monotone", False))
    config = TurboDiffusionConfig(
        topk=int(kwargs.pop("turbo_topk", DEFAULT_TURBO_TOPK)),
        monotone=bool(kwargs.pop("turbo_monotone", not no_monotone)),
        eos_early_stop=bool(kwargs.pop("turbo_eos_early_stop", True)),
        quota=float(kwargs.pop("turbo_quota", 0.0)),
        steps=kwargs.pop("turbo_steps", None),
        accept=kwargs.pop("turbo_accept", "entropy-bound"),
        threshold=kwargs.pop("turbo_threshold", 0.9),
        window=int(kwargs.pop("turbo_window", 0)),
        repair=bool(kwargs.pop("turbo_repair", False)),
        compact=bool(kwargs.pop("turbo_compact", False)),
        repeat_guard=int(kwargs.pop("turbo_repeat_guard", 16) or 0),
        # Read, not popped: the generic path may also consume this key.
        entropy_bound=kwargs.get("entropy_bound", None),
    )
    if config.topk <= 0:
        raise ValueError("turbo_topk must be a positive integer.")
    if config.steps is not None:
        config.steps = int(config.steps)
        if config.steps <= 0:
            raise ValueError("turbo_steps must be a positive integer.")
    if config.accept not in ("entropy-bound", "confidence"):
        raise ValueError(f"Unsupported turbo_accept: {config.accept!r}.")
    if config.accept == "confidence":
        # Validate here so a bad threshold fails before generation starts
        # rather than inside the denoising loop.
        threshold = config.threshold
        if isinstance(threshold, (tuple, list)):
            if len(threshold) != 3:
                raise ValueError(
                    "turbo_threshold must be a number or a "
                    "(high, low, switch_step) triple."
                )
            high, low, switch_step = threshold
            config.threshold = (float(high), float(low), int(switch_step))
        else:
            config.threshold = float(threshold)
    if config.repeat_guard < 0:
        raise ValueError("turbo_repeat_guard must be non-negative.")
    if config.entropy_bound is not None:
        config.entropy_bound = float(config.entropy_bound)
    return TurboDiffusionDenoiser(config), config.repeat_guard
| + | |
| + | |
| +def _bucket_size(n: int) -> int: | |
| + for bucket in _SHAPE_BUCKETS: | |
| + if n <= bucket: | |
| + return bucket | |
| + return _SHAPE_BUCKETS[-1] | |
| + | |
| + | |
def _turbo_temperature(ctx: DiffusionCanvasDenoiseContext, step: int) -> mx.array:
    """Linearly anneal the turbo temperature from ``t_max`` down to ``t_min``.

    ``step`` is the 0-based denoising step. ``t_min``/``t_max`` come from
    ``ctx.temperature_config`` (defaults 0.4 and 0.8). At step 0 the result is
    ``t_max``; at step ``ctx.max_denoising_steps - 1`` it is ``t_min``.

    The interpolation fraction is clamped to ``[0, 1]``. A step index outside
    that range (e.g. a loop driven by a different step budget) therefore
    stays at the nearest endpoint instead of extrapolating to a smaller,
    possibly zero or negative, temperature.

    Returns:
        A float32 scalar ``mx.array``.
    """
    temperature_config = ctx.temperature_config or {}
    t_min = float(temperature_config.get("t_min", 0.4))
    t_max = float(temperature_config.get("t_max", 0.8))
    span = max(ctx.max_denoising_steps - 1, 1)
    frac = min(max(1.0 - step / span, 0.0), 1.0)
    return mx.array(t_min + (t_max - t_min) * frac, dtype=mx.float32)
| + | |
| + | |
def _topk_logits(raw_logits: mx.array, k: int, chunk: int = 64):
    """Return the (unordered) top-``k`` logits and their vocabulary indices.

    ``k`` is clamped to the vocabulary size (the last axis). A two-stage
    search is used when the vocabulary splits evenly into ``chunk``-sized
    blocks, there are at least ``k`` blocks, and ``k`` is smaller than the
    vocabulary:

    1. take the ``k`` blocks with the largest block maximum;
    2. take the top ``k`` entries among those ``k * chunk`` candidates.

    A top-``k`` entry's block maximum is at least that entry, so its block is
    among the top ``k`` blocks (up to ties). Otherwise a single
    ``argpartition`` over the full last axis is used.

    Returns:
        ``(values, indices)``, each shaped ``raw_logits.shape[:-1] + (k,)``;
        entries are not sorted.
    """
    *lead, vocab_size = raw_logits.shape
    k = min(int(k), vocab_size)
    two_stage = (
        k < vocab_size
        and vocab_size % chunk == 0
        and k <= vocab_size // chunk
    )
    if not two_stage:
        full_idx = mx.argpartition(raw_logits, kth=-k, axis=-1)[..., -k:]
        return mx.take_along_axis(raw_logits, full_idx, axis=-1), full_idx

    num_blocks = vocab_size // chunk
    blocks = raw_logits.reshape(*lead, num_blocks, chunk)
    best_blocks = mx.argpartition(blocks.max(axis=-1), kth=-k, axis=-1)[..., -k:]
    candidates = mx.take_along_axis(
        blocks, mx.expand_dims(best_blocks, -1), axis=-2
    ).reshape(*lead, k * chunk)
    local_idx = mx.argpartition(candidates, kth=-k, axis=-1)[..., -k:]
    values = mx.take_along_axis(candidates, local_idx, axis=-1)
    # Map each winner back to its block id, then to a vocabulary index.
    owning_block = mx.take_along_axis(best_blocks, local_idx // chunk, axis=-1)
    return values, owning_block * chunk + (local_idx % chunk)
| + | |
| + | |
def _topk_postprocess(vals: mx.array, softcap: mx.array, temp: mx.array):
    """Turn raw top-k logits into probabilities and per-position entropy.

    The logits are soft-capped in float32 as ``tanh(vals / softcap) * softcap``,
    divided by ``temp``, and normalized with a log-softmax over the last axis.

    Returns:
        ``(probs, entropy)``; ``entropy`` is reduced over the last axis.
    """
    logits = mx.tanh(vals.astype(mx.float32) / softcap) * softcap / temp
    log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    probs = mx.exp(log_probs)
    return probs, -mx.sum(probs * log_probs, axis=-1)
| + | |
| + | |
def _entropy_transfer_mask(
    entropy: mx.array,
    entropy_bound: mx.array,
    quota: int = 0,
) -> mx.array:
    """Choose positions to accept under an entropy budget.

    Positions are ranked by increasing entropy along the last axis. The
    position at rank ``r`` is accepted when the summed entropy of ranks
    ``0..r`` minus the largest of them (rank ``r`` itself, since the ranking
    is ascending) is ``<= entropy_bound``. When ``quota > 0`` the ``quota``
    lowest-entropy positions are accepted regardless.

    Returns:
        A boolean mask in the original position order.
    """
    order = mx.argsort(entropy, axis=-1)
    ranked = mx.take_along_axis(entropy, order, axis=-1)
    budget_used = mx.cumsum(ranked, axis=-1) - mx.cummax(ranked, axis=-1)
    ranked_accept = budget_used <= entropy_bound
    if quota > 0:
        ranks = mx.arange(entropy.shape[-1])[None, :]
        ranked_accept = ranked_accept | (ranks < quota)
    # Scatter the rank-ordered decisions back to position order.
    return mx.put_along_axis(
        mx.zeros_like(ranked_accept), order, ranked_accept, axis=-1
    )
| + | |
| + | |
| +class _TurboSelfConditioner: | |
| + def __init__(self, decoder): | |
| + self.embed_tokens = decoder.embed_tokens | |
| + self.embed_scale = decoder.embed_scale | |
| + | |
| + def soft_embeddings(self, probs: mx.array, idx: mx.array) -> mx.array: | |
| + rows = self.embed_tokens(idx) | |
| + soft = (probs[..., None].astype(rows.dtype) * rows).sum(axis=-2) | |
| + return soft * self.embed_scale | |
| + | |
| + | |
| +class TurboDiffusionDenoiser: | |
| + def __init__(self, config: TurboDiffusionConfig): | |
| + self.config = config | |
| + | |
| + def __call__( | |
| + self, | |
| + ctx: DiffusionCanvasDenoiseContext, | |
| + ) -> Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]: | |
| + if ctx.batch_size != 1: | |
| + raise ValueError("Turbo diffusion generation only supports batch size 1.") | |
| + if self.config.compact: | |
| + yield from self._compact(ctx) | |
| + return | |
| + yield from self._full_canvas(ctx) | |
| + | |
| + @property | |
| + def _entropy_bound(self): | |
| + return self.config.entropy_bound | |
| + | |
| + def _softcap(self, ctx: DiffusionCanvasDenoiseContext) -> mx.array: | |
| + text_config = ctx.model.config.text_config | |
| + softcap = ( | |
| + getattr(ctx.model, "final_logit_softcapping", None) | |
| + or getattr(text_config, "final_logit_softcapping", 30.0) | |
| + ) | |
| + return mx.array(float(softcap), dtype=mx.float32) | |
| + | |
| + def _sample_from_topk(self, probs: mx.array, idx: mx.array, temperature: float): | |
| + if temperature <= 0: | |
| + selected = mx.argmax(probs, axis=-1) | |
| + else: | |
| + selected = mx.random.categorical( | |
| + mx.log(probs + 1e-20) / max(temperature, 1e-5) | |
| + ) | |
| + return mx.take_along_axis(idx, selected[..., None], axis=-1).squeeze(-1) | |
| + | |
| + def _accept_mask( | |
| + self, | |
| + *, | |
| + probs: mx.array, | |
| + entropy: mx.array, | |
| + step: int, | |
| + entropy_bound: mx.array, | |
| + quota: int, | |
| + flush: bool, | |
| + active_positions=None, | |
| + canvas_length: Optional[int] = None, | |
| + ) -> mx.array: | |
| + if flush: | |
| + return mx.ones(entropy.shape, dtype=mx.bool_) | |
| + if self.config.accept == "confidence": | |
| + if isinstance(self.config.threshold, (tuple, list)): | |
| + hi, lo, switch_step = self.config.threshold | |
| + threshold = hi if step < int(switch_step) else lo | |
| + else: | |
| + threshold = self.config.threshold | |
| + accept = probs.max(axis=-1) >= float(threshold) | |
| + if self.config.window > 0 and active_positions is not None: | |
| + if canvas_length is None or len(active_positions) == canvas_length: | |
| + frontier_limit = self.config.window | |
| + pos_of_active = mx.arange(canvas_length)[None, :] | |
| + else: | |
| + frontier_limit = active_positions[0] + self.config.window | |
| + pos_of_active = mx.array(active_positions, dtype=mx.int32)[None, :] | |
| + accept = accept & (pos_of_active < frontier_limit) | |
| + if not bool(mx.any(accept).item()): | |
| + best = mx.argmax(probs.max(axis=-1), axis=-1) | |
| + accept = mx.arange(probs.shape[-2])[None, :] == best[:, None] | |
| + return accept | |
| + return _entropy_transfer_mask(entropy, entropy_bound, quota) | |
| + | |
| + def _draft( | |
| + self, | |
| + ctx: DiffusionCanvasDenoiseContext, | |
| + token_ids, | |
| + reveal_mask, | |
| + step: int, | |
| + ) -> DiffusionCanvasDraft: | |
| + return DiffusionCanvasDraft( | |
| + text=_decode_diffusion_masked_draft( | |
| + ctx.tokenizer, | |
| + [int(token_id) for token_id in token_ids], | |
| + [bool(v) for v in reveal_mask], | |
| + ctx.skip_special_token_ids, | |
| + max_chars=ctx.unmasking_width, | |
| + ), | |
| + step=step, | |
| + total_steps=self.config.steps or ctx.max_denoising_steps, | |
| + ) | |
| + | |
| +    def _full_canvas_hidden( | |
| +        self, | |
| +        ctx: DiffusionCanvasDenoiseContext, | |
| +        tokens: mx.array, | |
| +        sc_embeddings: Optional[mx.array], | |
| +    ) -> mx.array: | |
| +        """Run every decoder layer over the whole canvas; return normed hidden states. | |
| + | |
| +        Shared by the denoising loop and the repair pass of ``_full_canvas`` so | |
| +        the two forwards cannot drift apart. ``sc_embeddings`` are the | |
| +        self-conditioning embeddings from the previous step (``None`` feeds | |
| +        zeros). | |
| +        """ | |
| +        decoder = ctx.model.model.decoder | |
| +        inputs_embeds = decoder.embed_tokens(tokens) * decoder.embed_scale | |
| +        soft = ( | |
| +            sc_embeddings.astype(inputs_embeds.dtype) | |
| +            if sc_embeddings is not None | |
| +            else mx.zeros_like(inputs_embeds) | |
| +        ) | |
| +        h = decoder.self_conditioning(inputs_embeds, soft) | |
| +        cache_list = ctx.kv_cache or [None] * len(decoder.layers) | |
| +        # Position offset comes from the first layer's cache; 0 when it is empty. | |
| +        offset = ( | |
| +            cache_list[0].offset | |
| +            if cache_list and getattr(cache_list[0], "keys", None) is not None | |
| +            else 0 | |
| +        ) | |
| +        for layer, cache in zip(decoder.layers, cache_list): | |
| +            h = layer( | |
| +                h, | |
| +                ctx.mask_mapping.get(layer.layer_type), | |
| +                cache, | |
| +                decoder=True, | |
| +                offset=offset, | |
| +            ) | |
| +        return decoder.norm(h) | |
| + | |
| +    def _full_canvas( | |
| +        self, | |
| +        ctx: DiffusionCanvasDenoiseContext, | |
| +    ) -> Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]: | |
| +        """Denoise the canvas, forwarding every position through the decoder each step. | |
| + | |
| +        Yields ``DiffusionCanvasDraft`` snapshots when ``ctx.show_unmasking`` is | |
| +        set, then exactly one ``DiffusionCanvasResult``. | |
| + | |
| +        Monotone mode (``self.config.monotone``): accepted positions are frozen | |
| +        permanently and, after step 0, logits are only computed for the | |
| +        still-active positions. The loop ends when nothing is active, or (with | |
| +        ``eos_early_stop``) when a frozen stop token has a fully frozen prefix. | |
| +        Non-monotone mode: unaccepted positions are re-noised every step and the | |
| +        loop ends once the top-1 canvas stays unchanged for | |
| +        ``stability_threshold`` consecutive steps with mean entropy below | |
| +        ``confidence_threshold``. Active-position bookkeeping reads batch row 0 | |
| +        only. | |
| +        """ | |
| +        decoder = ctx.model.model.decoder | |
| +        self_conditioner = _TurboSelfConditioner(decoder) | |
| +        softcap = self._softcap(ctx) | |
| +        entropy_bound = mx.array( | |
| +            self._entropy_bound if self._entropy_bound is not None else ctx.entropy_bound, | |
| +            dtype=mx.float32, | |
| +        ) | |
| +        stopping_config = ctx.stopping_config or {} | |
| +        stability_threshold = int(stopping_config.get("stability_threshold", 1)) | |
| +        confidence_threshold = float( | |
| +            stopping_config.get("confidence_threshold", 0.005) | |
| +        ) | |
| + | |
| +        canvas = _diffusion_initialize_canvas( | |
| +            ctx.batch_size, | |
| +            ctx.canvas_length, | |
| +            ctx.vocab_size, | |
| +            mx.int32, | |
| +        ) | |
| +        frozen = mx.zeros((ctx.batch_size, ctx.canvas_length), dtype=mx.bool_) | |
| +        committed = canvas | |
| +        sc_embeddings = None | |
| +        prev_argmax = None | |
| +        stable_count = 0 | |
| +        final_canvas = None | |
| +        emit_length = ctx.canvas_length | |
| +        active_positions = list(range(ctx.canvas_length)) | |
| +        steps = 0 | |
| +        work_tokens = 0 | |
| + | |
| +        if ctx.show_unmasking: | |
| +            yield self._draft( | |
| +                ctx, | |
| +                canvas[0].tolist(), | |
| +                [False] * ctx.canvas_length, | |
| +                0, | |
| +            ) | |
| + | |
| +        for step in range(ctx.max_denoising_steps): | |
| +            steps += 1 | |
| +            h = self._full_canvas_hidden(ctx, canvas, sc_embeddings) | |
| + | |
| +            if self.config.monotone and step > 0: | |
| +                # Frozen positions never change again; only score active ones. | |
| +                active_idx = mx.array(active_positions, dtype=mx.int32) | |
| +                h_active = h[:, active_idx, :] | |
| +            else: | |
| +                active_idx = None | |
| +                h_active = h | |
| +            work_tokens += int(h_active.shape[1]) | |
| + | |
| +            raw_logits = decoder.embed_tokens.as_linear(h_active) | |
| +            vals, idx = _topk_logits(raw_logits, self.config.topk) | |
| +            probs, entropy = _topk_postprocess( | |
| +                vals, | |
| +                softcap, | |
| +                _turbo_temperature(ctx, step), | |
| +            ) | |
| +            proposal = self._sample_from_topk(probs, idx, ctx.temperature).astype( | |
| +                mx.int32 | |
| +            ) | |
| +            top1 = ( | |
| +                mx.take_along_axis( | |
| +                    idx, | |
| +                    mx.argmax(probs, axis=-1)[..., None], | |
| +                    axis=-1, | |
| +                ) | |
| +                .squeeze(-1) | |
| +                .astype(mx.int32) | |
| +            ) | |
| + | |
| +            quota = 0 | |
| +            if self.config.quota > 0: | |
| +                quota = max(1, int(entropy.shape[-1] * self.config.quota)) | |
| +            # NOTE(review): flush stays True for every step from config.steps - 1 | |
| +            # on. In non-monotone mode nothing ends the loop at that point, so it | |
| +            # keeps running until max_denoising_steps or the stability check | |
| +            # fires -- confirm that is intended. | |
| +            flush = self.config.steps is not None and step >= self.config.steps - 1 | |
| +            accept = self._accept_mask( | |
| +                probs=probs, | |
| +                entropy=entropy, | |
| +                step=step, | |
| +                entropy_bound=entropy_bound, | |
| +                quota=quota, | |
| +                flush=flush, | |
| +                active_positions=active_positions, | |
| +                canvas_length=ctx.canvas_length, | |
| +            ) | |
| + | |
| +            if self.config.monotone: | |
| +                if active_idx is None: | |
| +                    # Step 0: logits cover the whole canvas. | |
| +                    committed = mx.where(accept, proposal, committed) | |
| +                    canvas = mx.where( | |
| +                        accept, | |
| +                        committed, | |
| +                        mx.random.randint(0, ctx.vocab_size, canvas.shape), | |
| +                    ) | |
| +                    frozen = frozen | accept | |
| +                    argmax_full = top1 | |
| +                else: | |
| +                    # Scatter active-subset results back into full-canvas arrays. | |
| +                    committed_active = mx.where( | |
| +                        accept, | |
| +                        proposal, | |
| +                        mx.take_along_axis(committed, active_idx[None, :], axis=-1), | |
| +                    ) | |
| +                    committed = mx.put_along_axis( | |
| +                        committed, | |
| +                        active_idx[None, :], | |
| +                        committed_active, | |
| +                        axis=-1, | |
| +                    ) | |
| +                    canvas_active = mx.where( | |
| +                        accept, | |
| +                        committed_active, | |
| +                        mx.random.randint(0, ctx.vocab_size, accept.shape), | |
| +                    ) | |
| +                    canvas = mx.put_along_axis( | |
| +                        canvas, | |
| +                        active_idx[None, :], | |
| +                        canvas_active, | |
| +                        axis=-1, | |
| +                    ) | |
| +                    frozen = mx.put_along_axis( | |
| +                        frozen, | |
| +                        active_idx[None, :], | |
| +                        mx.take_along_axis(frozen, active_idx[None, :], axis=-1) | |
| +                        | accept, | |
| +                        axis=-1, | |
| +                    ) | |
| +                    argmax_full = mx.put_along_axis( | |
| +                        committed.astype(top1.dtype), | |
| +                        active_idx[None, :], | |
| +                        top1, | |
| +                        axis=-1, | |
| +                    ) | |
| +            else: | |
| +                committed = mx.where(accept, proposal, canvas) | |
| +                canvas = mx.where( | |
| +                    accept, | |
| +                    committed, | |
| +                    mx.random.randint(0, ctx.vocab_size, canvas.shape), | |
| +                ) | |
| +                argmax_full = top1 | |
| + | |
| +            sc_active = self_conditioner.soft_embeddings(probs, idx) | |
| +            if self.config.monotone and active_idx is not None: | |
| +                # Frozen positions keep their previous self-conditioning input. | |
| +                base = ( | |
| +                    sc_embeddings | |
| +                    if sc_embeddings is not None | |
| +                    else mx.zeros( | |
| +                        (ctx.batch_size, ctx.canvas_length, sc_active.shape[-1]), | |
| +                        dtype=sc_active.dtype, | |
| +                    ) | |
| +                ) | |
| +                sc_embeddings = mx.put_along_axis( | |
| +                    base, | |
| +                    mx.broadcast_to( | |
| +                        active_idx[None, :, None], | |
| +                        (ctx.batch_size, active_idx.size, sc_active.shape[-1]), | |
| +                    ), | |
| +                    sc_active, | |
| +                    axis=1, | |
| +                ) | |
| +            else: | |
| +                sc_embeddings = sc_active | |
| + | |
| +            if self.config.monotone: | |
| +                mx.eval(frozen, committed, entropy) | |
| +                frozen_list = frozen[0].tolist() | |
| +                active_positions = [i for i, value in enumerate(frozen_list) if not value] | |
| + | |
| +                if self.config.eos_early_stop: | |
| +                    eos_pos = None | |
| +                    committed_list = committed[0].tolist() | |
| +                    for i, (token_id, is_frozen) in enumerate( | |
| +                        zip(committed_list, frozen_list) | |
| +                    ): | |
| +                        if is_frozen and ctx.tokenizer.stopping_criteria(int(token_id)): | |
| +                            eos_pos = i | |
| +                            break | |
| +                    if eos_pos is not None and all(frozen_list[: eos_pos + 1]): | |
| +                        final_canvas = committed | |
| +                        emit_length = eos_pos + 1 | |
| +                        break | |
| + | |
| +                if not active_positions: | |
| +                    final_canvas = committed | |
| +                    break | |
| +            else: | |
| +                mx.eval(argmax_full, entropy) | |
| +                mean_entropy = float(mx.mean(entropy).item()) | |
| + | |
| +            if prev_argmax is not None and not self.config.monotone: | |
| +                if bool(mx.all(argmax_full == prev_argmax).item()): | |
| +                    stable_count += 1 | |
| +                else: | |
| +                    stable_count = 0 | |
| +                if ( | |
| +                    stable_count >= stability_threshold | |
| +                    and mean_entropy < confidence_threshold | |
| +                ): | |
| +                    final_canvas = argmax_full | |
| +                    break | |
| +            prev_argmax = argmax_full | |
| + | |
| +            displayed_step = step + 1 | |
| +            should_show_unmasking = ctx.show_unmasking and ( | |
| +                displayed_step == 1 | |
| +                or displayed_step == ctx.max_denoising_steps | |
| +                or displayed_step % ctx.unmasking_interval == 0 | |
| +                or flush | |
| +            ) | |
| +            if should_show_unmasking: | |
| +                if self.config.monotone: | |
| +                    yield self._draft( | |
| +                        ctx, | |
| +                        committed[0].tolist(), | |
| +                        frozen[0].tolist(), | |
| +                        displayed_step, | |
| +                    ) | |
| +                else: | |
| +                    reveal = [True] * ctx.canvas_length if flush else [False] * ctx.canvas_length | |
| +                    yield self._draft(ctx, argmax_full[0].tolist(), reveal, displayed_step) | |
| + | |
| +        if final_canvas is None: | |
| +            final_canvas = committed if self.config.monotone else argmax_full | |
| + | |
| +        if self.config.repair and emit_length == ctx.canvas_length: | |
| +            # One extra greedy (top-1) forward over the finished canvas. | |
| +            h = self._full_canvas_hidden(ctx, final_canvas, sc_embeddings) | |
| +            vals, idx = _topk_logits(decoder.embed_tokens.as_linear(h), self.config.topk) | |
| +            final_canvas = ( | |
| +                mx.take_along_axis(idx, mx.argmax(vals, axis=-1)[..., None], axis=-1) | |
| +                .squeeze(-1) | |
| +                .astype(mx.int32) | |
| +            ) | |
| +            steps += 1 | |
| +            work_tokens += ctx.canvas_length | |
| + | |
| +        yield DiffusionCanvasResult( | |
| +            canvas=final_canvas[:, :emit_length], | |
| +            canvas_tokens=ctx.canvas_length, | |
| +            denoising_steps=steps, | |
| +            work_tokens=work_tokens, | |
| +        ) | |
| + | |
| +    def _compact( | |
| +        self, | |
| +        ctx: DiffusionCanvasDenoiseContext, | |
| +    ) -> Generator[DiffusionCanvasDraft | DiffusionCanvasResult, None, None]: | |
| +        """Monotone denoising that re-forwards only the positions that can change. | |
| + | |
| +        Each step forwards (via ``TurboCanvasRunner``) the still-active positions | |
| +        plus the positions frozen on the previous step. This refreshes their | |
| +        buffered keys/values with the committed token. Position lists are padded | |
| +        to ``_bucket_size`` by repeating the last entry. Accepted positions are | |
| +        frozen permanently. ``self.config.quota`` and ``self.config.window`` are | |
| +        passed to ``_accept_mask`` the same way ``_full_canvas`` passes them. | |
| +        Yields ``DiffusionCanvasDraft`` snapshots when ``ctx.show_unmasking`` is | |
| +        set, then exactly one ``DiffusionCanvasResult``. | |
| + | |
| +        Raises: | |
| +            ValueError: if a decoder attention mask is given, if | |
| +                ``ctx.batch_size`` is not 1, or if ``ctx.kv_cache`` is empty. | |
| +        """ | |
| +        if ctx.decoder_attention_mask is not None: | |
| +            raise ValueError( | |
| +                "turbo_compact does not support padded or static decoder masks yet." | |
| +            ) | |
| +        # The canvas below and the runner's buffers are built with batch size 1; | |
| +        # fail loudly instead of silently returning a single row. | |
| +        if ctx.batch_size != 1: | |
| +            raise ValueError( | |
| +                f"turbo_compact only supports batch_size=1, got {ctx.batch_size}." | |
| +            ) | |
| +        if not ctx.kv_cache: | |
| +            raise ValueError("turbo_compact requires a per-layer KV cache.") | |
| + | |
| +        from ..models.diffusion_gemma.diffusion_turbo_runner import TurboCanvasRunner | |
| + | |
| +        decoder = ctx.model.model.decoder | |
| +        self_conditioner = _TurboSelfConditioner(decoder) | |
| +        cache0 = ctx.kv_cache[0] | |
| +        # Absolute canvas positions are prefix_offset + canvas-relative index. | |
| +        prefix_offset = ( | |
| +            int(cache0.offset) if getattr(cache0, "keys", None) is not None else 0 | |
| +        ) | |
| +        runner = TurboCanvasRunner( | |
| +            ctx.model, | |
| +            ctx.kv_cache, | |
| +            ctx.canvas_length, | |
| +            prefix_offset + ctx.canvas_length + 8, | |
| +        ) | |
| +        softcap = self._softcap(ctx) | |
| +        entropy_bound = mx.array( | |
| +            self._entropy_bound if self._entropy_bound is not None else ctx.entropy_bound, | |
| +            dtype=mx.float32, | |
| +        ) | |
| +        hidden_size = decoder.config.hidden_size | |
| +        canvas_dev = mx.random.randint(0, ctx.vocab_size, (1, ctx.canvas_length)).astype( | |
| +            mx.int32 | |
| +        ) | |
| +        committed_dev = canvas_dev | |
| +        sc_full = mx.zeros((1, ctx.canvas_length, hidden_size), dtype=mx.bfloat16) | |
| +        frozen = [False] * ctx.canvas_length | |
| +        # Positions after a frozen stop token that were force-frozen; hidden in drafts. | |
| +        tail_dropped = [False] * ctx.canvas_length | |
| +        newly_frozen: list[int] = [] | |
| +        steps = 0 | |
| +        work_tokens = 0 | |
| +        emit_length = ctx.canvas_length | |
| + | |
| +        if ctx.show_unmasking: | |
| +            yield self._draft( | |
| +                ctx, | |
| +                canvas_dev[0].tolist(), | |
| +                [False] * ctx.canvas_length, | |
| +                0, | |
| +            ) | |
| + | |
| +        for step in range(ctx.max_denoising_steps): | |
| +            active = [i for i in range(ctx.canvas_length) if not frozen[i]] | |
| +            if not active: | |
| +                break | |
| +            # Re-forward the positions frozen last step once, so the runner's | |
| +            # K/V buffers are computed from their committed tokens. | |
| +            forward_positions = sorted(active + newly_frozen) | |
| +            steps += 1 | |
| +            work_tokens += len(forward_positions) | |
| + | |
| +            # Pad to bucket sizes by repeating the last position. | |
| +            forward_bucket = _bucket_size(len(forward_positions)) | |
| +            forward_bucketed = forward_positions + [forward_positions[-1]] * ( | |
| +                forward_bucket - len(forward_positions) | |
| +            ) | |
| +            active_bucket = _bucket_size(len(active)) | |
| +            active_bucketed = active + [active[-1]] * (active_bucket - len(active)) | |
| + | |
| +            rel = mx.array(forward_bucketed, dtype=mx.int32) | |
| +            runner.set_forward_positions(rel) | |
| +            h_forward = runner( | |
| +                canvas_dev[:, rel], | |
| +                prefix_offset + rel, | |
| +                sc_full[:, rel, :], | |
| +            ) | |
| + | |
| +            pos_in_forward = {p: i for i, p in enumerate(forward_positions)} | |
| +            active_idx = mx.array( | |
| +                [pos_in_forward[p] for p in active_bucketed], | |
| +                dtype=mx.int32, | |
| +            ) | |
| +            raw_logits = decoder.embed_tokens.as_linear(h_forward[:, active_idx, :]) | |
| +            vals, idx = _topk_logits(raw_logits, self.config.topk) | |
| +            probs, entropy = _topk_postprocess( | |
| +                vals, | |
| +                softcap, | |
| +                _turbo_temperature(ctx, step), | |
| +            ) | |
| + | |
| +            # Drop the bucket padding before sampling and acceptance. | |
| +            real_count = len(active) | |
| +            real_probs = probs[:, :real_count, :] | |
| +            real_idx = idx[:, :real_count, :] | |
| +            real_entropy = entropy[:, :real_count] | |
| +            real_active_pos = mx.array(active, dtype=mx.int32) | |
| +            proposal = self._sample_from_topk( | |
| +                real_probs, | |
| +                real_idx, | |
| +                ctx.temperature, | |
| +            ).astype(mx.int32) | |
| +            # Same quota rule as _full_canvas (fraction of the scored positions). | |
| +            quota = 0 | |
| +            if self.config.quota > 0: | |
| +                quota = max(1, int(real_count * self.config.quota)) | |
| +            flush = self.config.steps is not None and step >= self.config.steps - 1 | |
| +            # Passing active_positions/canvas_length lets _accept_mask apply the | |
| +            # configured window, exactly as in _full_canvas. | |
| +            accept = self._accept_mask( | |
| +                probs=real_probs, | |
| +                entropy=real_entropy, | |
| +                step=step, | |
| +                entropy_bound=entropy_bound, | |
| +                quota=quota, | |
| +                flush=flush, | |
| +                active_positions=active, | |
| +                canvas_length=ctx.canvas_length, | |
| +            ) | |
| + | |
| +            noise = mx.random.randint(0, ctx.vocab_size, accept.shape).astype(mx.int32) | |
| +            canvas_dev = mx.put_along_axis( | |
| +                canvas_dev, | |
| +                real_active_pos[None, :], | |
| +                mx.where(accept, proposal, noise), | |
| +                axis=-1, | |
| +            ) | |
| +            committed_dev = mx.put_along_axis( | |
| +                committed_dev, | |
| +                real_active_pos[None, :], | |
| +                mx.where( | |
| +                    accept, | |
| +                    proposal, | |
| +                    mx.take_along_axis( | |
| +                        committed_dev, | |
| +                        real_active_pos[None, :], | |
| +                        axis=-1, | |
| +                    ), | |
| +                ), | |
| +                axis=-1, | |
| +            ) | |
| +            sc_active = self_conditioner.soft_embeddings(real_probs, real_idx) | |
| +            sc_full = mx.put_along_axis( | |
| +                sc_full, | |
| +                mx.broadcast_to( | |
| +                    real_active_pos[None, :, None], | |
| +                    (1, real_count, hidden_size), | |
| +                ), | |
| +                sc_active.astype(sc_full.dtype), | |
| +                axis=1, | |
| +            ) | |
| + | |
| +            mx.eval(canvas_dev, committed_dev, sc_full) | |
| +            accept_list = accept[0].tolist() | |
| +            if not any(accept_list) and not flush: | |
| +                # Guarantee progress: commit the single most confident position. | |
| +                best = int(mx.argmax(real_probs.max(axis=-1), axis=-1).item()) | |
| +                accept_list[best] = True | |
| +                p = active[best] | |
| +                token_id = int(proposal[0, best].item()) | |
| +                pos = mx.array([[p]], dtype=mx.int32) | |
| +                token = mx.array([[token_id]], dtype=mx.int32) | |
| +                canvas_dev = mx.put_along_axis(canvas_dev, pos, token, axis=-1) | |
| +                committed_dev = mx.put_along_axis(committed_dev, pos, token, axis=-1) | |
| + | |
| +            newly_frozen = [] | |
| +            for i, accepted in enumerate(accept_list): | |
| +                if accepted: | |
| +                    pos = active[i] | |
| +                    frozen[pos] = True | |
| +                    newly_frozen.append(pos) | |
| + | |
| +            if self.config.eos_early_stop and newly_frozen: | |
| +                committed_now = committed_dev[0].tolist() | |
| +                eos_pos = None | |
| +                for i, token_id in enumerate(committed_now): | |
| +                    if frozen[i] and ctx.tokenizer.stopping_criteria(int(token_id)): | |
| +                        eos_pos = i | |
| +                        break | |
| +                if eos_pos is not None: | |
| +                    if all(frozen[: eos_pos + 1]): | |
| +                        emit_length = eos_pos + 1 | |
| +                        break | |
| +                    # Stop denoising everything after the stop token. | |
| +                    for i in range(eos_pos + 1, ctx.canvas_length): | |
| +                        if not frozen[i]: | |
| +                            frozen[i] = True | |
| +                            tail_dropped[i] = True | |
| +                    newly_frozen = [p for p in newly_frozen if p <= eos_pos] | |
| + | |
| +            displayed_step = step + 1 | |
| +            should_show_unmasking = ctx.show_unmasking and ( | |
| +                displayed_step == 1 | |
| +                or displayed_step == ctx.max_denoising_steps | |
| +                or displayed_step % ctx.unmasking_interval == 0 | |
| +                or flush | |
| +            ) | |
| +            if should_show_unmasking: | |
| +                committed_view = committed_dev[0].tolist() | |
| +                reveal = [ | |
| +                    frozen[i] and not tail_dropped[i] for i in range(ctx.canvas_length) | |
| +                ] | |
| +                yield self._draft(ctx, committed_view, reveal, displayed_step) | |
| + | |
| +        committed = [int(token_id) for token_id in committed_dev[0].tolist()] | |
| + | |
| +        if self.config.repair and emit_length == ctx.canvas_length: | |
| +            # One extra greedy (top-1) forward over the whole canvas. | |
| +            rel = mx.arange(ctx.canvas_length, dtype=mx.int32) | |
| +            runner.set_forward_positions(rel) | |
| +            h_forward = runner( | |
| +                mx.array([committed], dtype=mx.int32), | |
| +                prefix_offset + rel, | |
| +                sc_full, | |
| +            ) | |
| +            vals, idx = _topk_logits( | |
| +                decoder.embed_tokens.as_linear(h_forward), | |
| +                self.config.topk, | |
| +            ) | |
| +            repaired = ( | |
| +                mx.take_along_axis(idx, mx.argmax(vals, axis=-1)[..., None], axis=-1) | |
| +                .squeeze(-1) | |
| +                .astype(mx.int32) | |
| +            ) | |
| +            mx.eval(repaired) | |
| +            committed = [int(token_id) for token_id in repaired[0].tolist()] | |
| +            steps += 1 | |
| +            work_tokens += ctx.canvas_length | |
| + | |
| +        yield DiffusionCanvasResult( | |
| +            canvas=mx.array([committed[:emit_length]], dtype=mx.int32), | |
| +            canvas_tokens=ctx.canvas_length, | |
| +            denoising_steps=steps, | |
| +            work_tokens=work_tokens, | |
| +        ) | |
| diff --git a/mlx_vlm/models/diffusion_gemma/diffusion_turbo_runner.py b/mlx_vlm/models/diffusion_gemma/diffusion_turbo_runner.py | |
| new file mode 100644 | |
| index 0000000..2439411 | |
| --- /dev/null | |
| +++ b/mlx_vlm/models/diffusion_gemma/diffusion_turbo_runner.py | |
| @@ -0,0 +1,238 @@ | |
| +"""Active-set-compacted decoder runner for DiffusionGemma turbo diffusion.""" | |
| + | |
| +from __future__ import annotations | |
| + | |
| +from typing import List, Optional | |
| + | |
| +import mlx.core as mx | |
| + | |
| +from .language import _cache_offset, _cache_state | |
| + | |
| + | |
| +def _rope_tables(dims: int, base: float, positions: int): | |
| +    """Return ``(cos, sin)`` RoPE tables for absolute positions ``0..positions-1``. | |
| + | |
| +    Inverse frequencies are ``base ** -(i / dims)`` for even ``i`` in | |
| +    ``[0, dims)``; both tables have shape ``(positions, len(range(0, dims, 2)))``. | |
| +    """ | |
| +    even_dims = mx.arange(0, dims, 2, dtype=mx.float32) | |
| +    inv_freq = base ** (-(even_dims / dims)) | |
| +    position_col = mx.expand_dims(mx.arange(positions, dtype=mx.float32), 1) | |
| +    angles = position_col * mx.expand_dims(inv_freq, 0) | |
| +    return mx.cos(angles), mx.sin(angles) | |
| + | |
| + | |
| +def _rope_tables_from_freqs(freqs: mx.array, positions: int): | |
| +    """Return ``(cos, sin)`` tables with ``angle[p, j] = p / freqs[j]``. | |
| + | |
| +    ``freqs`` is a 1-D array of per-pair divisors; both tables have shape | |
| +    ``(positions, freqs.size)``. | |
| +    """ | |
| +    position_col = mx.expand_dims(mx.arange(positions, dtype=mx.float32), 1) | |
| +    angles = position_col / mx.expand_dims(freqs, 0) | |
| +    return mx.cos(angles), mx.sin(angles) | |
| + | |
| + | |
| +def _apply_rope_gathered(x: mx.array, cos: mx.array, sin: mx.array): | |
| +    """Rotate ``x`` with precomputed per-position ``cos``/``sin`` rows (split-half RoPE). | |
| + | |
| +    The last axis of ``x`` is treated as two halves. ``cos``/``sin`` are 2-D | |
| +    ``(seq, half)`` tables broadcast over the two leading axes of ``x`` | |
| +    (presumably ``(batch, heads, seq, dim)`` -- confirm at call sites). The | |
| +    rotation runs in float32 and the result is cast back to ``x.dtype``. | |
| +    """ | |
| +    split = x.shape[-1] // 2 | |
| +    first = x[..., :split].astype(mx.float32) | |
| +    second = x[..., split:].astype(mx.float32) | |
| +    cos_b, sin_b = cos[None, None, :, :], sin[None, None, :, :] | |
| +    rotated = [first * cos_b - second * sin_b, second * cos_b + first * sin_b] | |
| +    return mx.concatenate(rotated, axis=-1).astype(x.dtype) | |
| + | |
| + | |
| +class TurboCanvasRunner: | |
| +    """Forward DiffusionGemma decoder layers over a compacted canvas subset. | |
| + | |
| +    A per-layer key/value buffer spans the whole canvas (``canvas_length`` | |
| +    slots). Each call computes Q/K/V only for the forwarded canvas positions, | |
| +    scatters the new K/V into the buffer, and attends (with no mask) over the | |
| +    snapshotted prompt prefix followed by the entire canvas buffer. Call | |
| +    ``set_forward_positions`` with the canvas-relative positions before each | |
| +    ``__call__``. | |
| + | |
| +    NOTE(review): the buffers and scatter indices have a leading batch | |
| +    dimension of 1, so batch size 1 is assumed. | |
| +    """ | |
| + | |
| +    def __init__(self, model, kv_cache, canvas_length: int, max_position: int): | |
| +        """Precompute RoPE tables, snapshot the prefix cache and allocate buffers. | |
| + | |
| +        Args: | |
| +            model: DiffusionGemma model; only ``model.model.decoder`` is used. | |
| +            kv_cache: per-layer caches holding the encoded prefix, zipped with | |
| +                the decoder layers. | |
| +            canvas_length: number of canvas slots in each K/V buffer. | |
| +            max_position: number of rows in the RoPE tables; every absolute | |
| +                position passed to ``__call__`` must be smaller. | |
| +        """ | |
| +        self.decoder = model.model.decoder | |
| +        self.config = self.decoder.config | |
| +        self.layers = self.decoder.layers | |
| +        self.kv_cache = kv_cache | |
| + | |
| +        cfg = self.config | |
| +        sliding_params = cfg.rope_parameters.get("sliding_attention", {}) | |
| +        full_params = cfg.rope_parameters.get("full_attention", {}) | |
| +        # Sliding layers: plain RoPE over the full head_dim. | |
| +        self.cos_s, self.sin_s = _rope_tables( | |
| +            cfg.head_dim, | |
| +            float(sliding_params.get("rope_theta", 10000.0)), | |
| +            max_position, | |
| +        ) | |
| + | |
| +        # Full-attention layers: partial RoPE over `full_rotated` dims of the | |
| +        # global head, with frequencies multiplied by `factor`. | |
| +        global_dim = cfg.global_head_dim or cfg.head_dim | |
| +        partial = float(full_params.get("partial_rotary_factor", 1.0)) | |
| +        rope_angles = int(partial * global_dim // 2) | |
| +        self.full_rotated = 2 * rope_angles | |
| +        factor = float(full_params.get("factor", 1.0)) | |
| +        base = float(full_params.get("rope_theta", 10000.0)) | |
| +        exponents = mx.arange(0, self.full_rotated, 2, dtype=mx.float32) / global_dim | |
| +        freqs = factor * (base**exponents) | |
| +        self.cos_f, self.sin_f = _rope_tables_from_freqs(freqs, max_position) | |
| +        mx.eval(self.cos_s, self.sin_s, self.cos_f, self.sin_f) | |
| + | |
| +        # Snapshot prefix K/V per layer (None where the cache holds no state). | |
| +        self.prefix_k: List[Optional[mx.array]] = [] | |
| +        self.prefix_v: List[Optional[mx.array]] = [] | |
| +        window = max(cfg.sliding_window - 1, 0) | |
| +        for layer, cache in zip(self.layers, kv_cache): | |
| +            state = _cache_state(cache) | |
| +            if state is None: | |
| +                self.prefix_k.append(None) | |
| +                self.prefix_v.append(None) | |
| +                continue | |
| +            keys, values = state | |
| +            offset = _cache_offset(cache) | |
| +            if layer.layer_type == "sliding_attention": | |
| +                # Keep only the last `sliding_window - 1` prefix entries when | |
| +                # the cached sequence is longer than that. | |
| +                encoder_len = keys.shape[2] | |
| +                if window and encoder_len > window and offset >= encoder_len: | |
| +                    keys = keys[:, :, -window:, :] | |
| +                    values = values[:, :, -window:, :] | |
| +            self.prefix_k.append(keys) | |
| +            self.prefix_v.append(values) | |
| + | |
| +        # Buffer dtype: the quantization scales' dtype for quantized embeddings, | |
| +        # otherwise the embedding weight dtype. Falling back to the weight dtype | |
| +        # (instead of a hard-coded bfloat16) presumably keeps buffered K/V in the | |
| +        # same dtype as a float16/float32 model's prefix cache and queries. | |
| +        embed = self.decoder.embed_tokens | |
| +        if hasattr(embed, "scales"): | |
| +            dtype = embed.scales.dtype | |
| +        elif hasattr(embed, "weight"): | |
| +            dtype = embed.weight.dtype | |
| +        else: | |
| +            dtype = mx.bfloat16 | |
| +        self.buf_k: List[mx.array] = [] | |
| +        self.buf_v: List[mx.array] = [] | |
| +        for layer in self.layers: | |
| +            attn = layer.self_attn | |
| +            shape = (1, attn.n_kv_heads, canvas_length, attn.head_dim) | |
| +            self.buf_k.append(mx.zeros(shape, dtype=dtype)) | |
| +            self.buf_v.append(mx.zeros(shape, dtype=dtype)) | |
| +        self._scatter_idx = [] | |
| + | |
| +    def set_forward_positions(self, canvas_rel_positions: mx.array): | |
| +        """Precompute per-layer scatter indices for a 1-D array of canvas positions. | |
| + | |
| +        Must be called before ``__call__`` whenever the forwarded positions | |
| +        change; the indices address the sequence axis of the K/V buffers. | |
| +        """ | |
| +        self._scatter_idx = [] | |
| +        for layer in self.layers: | |
| +            attn = layer.self_attn | |
| +            idx = mx.broadcast_to( | |
| +                canvas_rel_positions[None, None, :, None], | |
| +                (1, attn.n_kv_heads, canvas_rel_positions.size, attn.head_dim), | |
| +            ) | |
| +            self._scatter_idx.append(idx) | |
| + | |
| +    def _proportional_rope( | |
| +        self, | |
| +        x: mx.array, | |
| +        head_dim: int, | |
| +        cos: mx.array, | |
| +        sin: mx.array, | |
| +    ) -> mx.array: | |
| +        """Rotate the first ``full_rotated // 2`` dims of each half of ``x``. | |
| + | |
| +        The leading rotary dims of both halves are gathered, rotated together | |
| +        with ``_apply_rope_gathered``, and written back; the remaining dims of | |
| +        each half pass through unchanged. | |
| +        """ | |
| +        half = head_dim // 2 | |
| +        rotary_half = self.full_rotated // 2 | |
| +        left = x[..., :half] | |
| +        right = x[..., half:] | |
| +        rotary = mx.concatenate( | |
| +            [left[..., :rotary_half], right[..., :rotary_half]], | |
| +            axis=-1, | |
| +        ) | |
| +        rotary = _apply_rope_gathered(rotary, cos, sin) | |
| +        left = mx.concatenate( | |
| +            [rotary[..., :rotary_half], left[..., rotary_half:]], | |
| +            axis=-1, | |
| +        ) | |
| +        right = mx.concatenate( | |
| +            [rotary[..., rotary_half:], right[..., rotary_half:]], | |
| +            axis=-1, | |
| +        ) | |
| +        return mx.concatenate([left, right], axis=-1) | |
| + | |
| +    def __call__( | |
| +        self, | |
| +        tokens_f: mx.array, | |
| +        positions_f: mx.array, | |
| +        sc_f: Optional[mx.array], | |
| +    ): | |
| +        """Forward the selected canvas positions; return ``decoder.norm`` output. | |
| + | |
| +        Args: | |
| +            tokens_f: token ids of the forwarded positions (presumably (1, F)). | |
| +            positions_f: 1-D absolute positions used to index the RoPE tables. | |
| +            sc_f: self-conditioning embeddings for those positions, or ``None`` | |
| +                to feed zeros. | |
| + | |
| +        Side effect: overwrites the per-layer K/V buffers at the positions set | |
| +        by the last ``set_forward_positions`` call. | |
| +        """ | |
| +        decoder = self.decoder | |
| +        inputs_embeds = decoder.embed_tokens(tokens_f) * decoder.embed_scale | |
| +        soft = ( | |
| +            sc_f.astype(inputs_embeds.dtype) | |
| +            if sc_f is not None | |
| +            else mx.zeros_like(inputs_embeds) | |
| +        ) | |
| +        h = decoder.self_conditioning(inputs_embeds, soft) | |
| + | |
| +        batch_size, forward_count, _ = h.shape | |
| +        cos_s = self.cos_s[positions_f] | |
| +        sin_s = self.sin_s[positions_f] | |
| +        cos_f = self.cos_f[positions_f] | |
| +        sin_f = self.sin_f[positions_f] | |
| + | |
| +        for layer_index, layer in enumerate(self.layers): | |
| +            attn = layer.self_attn | |
| +            residual = h | |
| +            h_norm = layer.input_layernorm(h) | |
| + | |
| +            q = attn.q_proj(h_norm).reshape( | |
| +                batch_size, | |
| +                forward_count, | |
| +                attn.n_heads, | |
| +                attn.head_dim, | |
| +            ) | |
| +            q = attn.q_norm(q).transpose(0, 2, 1, 3) | |
| +            k = attn.k_proj(h_norm).reshape( | |
| +                batch_size, | |
| +                forward_count, | |
| +                attn.n_kv_heads, | |
| +                attn.head_dim, | |
| +            ) | |
| +            # Without a v_proj, values reuse the (pre-norm) key projection. | |
| +            v_raw = ( | |
| +                attn.v_proj(h_norm).reshape( | |
| +                    batch_size, | |
| +                    forward_count, | |
| +                    attn.n_kv_heads, | |
| +                    attn.head_dim, | |
| +                ) | |
| +                if attn.v_proj is not None | |
| +                else k | |
| +            ) | |
| +            k = attn.k_norm(k).transpose(0, 2, 1, 3) | |
| +            v = attn.v_norm(v_raw).transpose(0, 2, 1, 3) | |
| + | |
| +            if attn.is_sliding: | |
| +                q = _apply_rope_gathered(q, cos_s, sin_s) | |
| +                k = _apply_rope_gathered(k, cos_s, sin_s) | |
| +            elif self.full_rotated: | |
| +                q = self._proportional_rope(q, attn.head_dim, cos_f, sin_f) | |
| +                k = self._proportional_rope(k, attn.head_dim, cos_f, sin_f) | |
| + | |
| +            # Write this step's K/V into the full-canvas buffers. | |
| +            scatter_idx = self._scatter_idx[layer_index] | |
| +            self.buf_k[layer_index] = mx.put_along_axis( | |
| +                self.buf_k[layer_index], | |
| +                scatter_idx, | |
| +                k.astype(self.buf_k[layer_index].dtype), | |
| +                axis=2, | |
| +            ) | |
| +            self.buf_v[layer_index] = mx.put_along_axis( | |
| +                self.buf_v[layer_index], | |
| +                scatter_idx, | |
| +                v.astype(self.buf_v[layer_index].dtype), | |
| +                axis=2, | |
| +            ) | |
| + | |
| +            keys = self.buf_k[layer_index] | |
| +            values = self.buf_v[layer_index] | |
| +            if self.prefix_k[layer_index] is not None: | |
| +                keys = mx.concatenate([self.prefix_k[layer_index], keys], axis=2) | |
| +                values = mx.concatenate([self.prefix_v[layer_index], values], axis=2) | |
| + | |
| +            out = mx.fast.scaled_dot_product_attention( | |
| +                q, | |
| +                keys, | |
| +                values, | |
| +                scale=1.0, | |
| +                mask=None, | |
| +            ) | |
| +            out = out.transpose(0, 2, 1, 3).reshape(batch_size, forward_count, -1) | |
| +            h_norm = attn.o_proj(out) | |
| + | |
| +            h_norm = layer.post_attention_layernorm(h_norm) | |
| +            h = residual + h_norm | |
| + | |
| +            # Dense MLP branch and routed-experts branch, summed. | |
| +            residual = h | |
| +            h1 = layer.pre_feedforward_layernorm(h) | |
| +            h1 = layer.mlp(h1) | |
| +            h1 = layer.post_feedforward_layernorm_1(h1) | |
| + | |
| +            flat = residual.reshape(-1, residual.shape[-1]) | |
| +            top_k_indices, top_k_weights = layer.router(flat) | |
| +            h2 = layer.pre_feedforward_layernorm_2(flat) | |
| +            h2 = layer.experts(h2, top_k_indices, top_k_weights) | |
| +            h2 = h2.reshape(residual.shape) | |
| +            h2 = layer.post_feedforward_layernorm_2(h2) | |
| + | |
| +            h = layer.post_feedforward_layernorm(h1 + h2) | |
| +            h = residual + h | |
| +            h = h * layer.layer_scalar | |
| + | |
| +        return decoder.norm(h) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment