UnhurriedDawn's picture
paper
0e38dc9
# app.py
import torch
import torch.nn.functional as F
import torch.distributions as dists
import transformers
from transformers import AutoTokenizer
from peft import PeftModel, PeftConfig
import numpy as np
import random
import time
import os
import html
from typing import List, Dict, Optional, Tuple, Iterator, Set
import gradio as gr
import spaces # 导入 spaces 模块
# Suppress some Hugging Face warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Import necessary model classes from the local directory
# Make sure the 'model_cache' directory is in your Hugging Face Space repository.
from model_cache.llada.modeling_llada import LLaDAModelLM
from model_cache.llada.configuration_llada import LLaDAConfig
# --- Helper functions ---
def set_seed(seed):
    """Seed every random number generator this app relies on.

    Seeds PyTorch, Python's ``random`` module and NumPy. When CUDA is
    available it additionally seeds all CUDA devices and puts cuDNN into
    deterministic mode with the autotuner (``benchmark``) disabled.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    """Build the additive block-causal attention mask for the whole sequence.

    Positions ``[0, prompt_length)`` form the prompt; the remaining positions
    are split into consecutive blocks of ``block_size`` (the last block may be
    shorter). A query position may attend to a key position when the key's
    block index is <= the query's block index:

    * prompt rows attend only to the prompt;
    * a generation block attends to the prompt, every earlier block and itself.

    Allowed entries are ``0``; disallowed entries are ``-inf``.

    Args:
        prompt_length: number of prompt tokens at the start of the sequence.
        max_length: total sequence length covered by the mask.
        block_size: size of each generation block (must be > 0).
        device: device of the returned tensor.
        dtype: dtype of the returned tensor; defaults to ``torch.bfloat16``.

    Returns:
        Tensor of shape ``(1, 1, max_length, max_length)``.
    """
    if dtype is None:
        dtype = torch.bfloat16
    positions = torch.arange(max_length, device=device)
    # Block index per position: 0 for the prompt, 1, 2, ... for generation blocks.
    # Computed in one vectorized pass instead of a Python loop over block pairs,
    # which issued O(num_blocks^2) small slice assignments.
    block_ids = torch.where(
        positions < prompt_length,
        torch.zeros_like(positions),
        (positions - prompt_length) // block_size + 1,
    )
    allowed = block_ids.unsqueeze(1) >= block_ids.unsqueeze(0)  # (max_length, max_length)
    attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
    attention_mask.masked_fill_(allowed, 0)  # `allowed` broadcasts over the two leading dims
    return attention_mask
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    """Slice the per-step attention mask out of the full sequence mask.

    The result has one row per input position ``[start_pos, start_pos + input_length)``
    and ``cache_length + input_length`` columns: the first ``cache_length``
    columns are the mask entries against the cached prefix, the rest are the
    entries among the input positions themselves. Any column not covered by
    those two slices stays ``-inf``.

    Returns:
        A new tensor of shape ``(1, 1, input_length, cache_length + input_length)``
        with ``full_mask``'s dtype and device.
    """
    rows = slice(start_pos, start_pos + input_length)
    result = full_mask.new_full((1, 1, input_length, cache_length + input_length), -torch.inf)
    result[..., :cache_length] = full_mask[:, :, rows, :cache_length]
    result[..., cache_length:] = full_mask[:, :, rows, rows]
    return result
def top_p_logits(logits, top_p=None):
    """Apply nucleus (top-p) filtering along the last dimension of ``logits``.

    Keeps the smallest set of highest-scoring tokens whose cumulative
    probability exceeds ``top_p``. The token that crosses the threshold is kept,
    and the top token is always kept. Every other logit is set to the dtype's
    minimum value.

    Args:
        logits: unnormalized scores; filtering is applied over dim -1.
        top_p: cumulative-probability cutoff. ``None`` or a value >= 1 disables
            filtering (previously ``None`` - the default - raised a TypeError).

    Returns:
        The filtered logits (the input tensor itself when filtering is disabled).
    """
    if top_p is None or top_p >= 1:
        # Nothing to filter. This also prevents float rounding in the cumsum
        # from trimming tail tokens when top_p == 1.
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the token that first crosses the threshold is kept too.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Map the removal flags from sorted order back to vocabulary order.
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
def top_k_logits(logits, top_k=None):
    """Keep only the ``top_k`` largest logits along the last dimension.

    Logits strictly smaller than the k-th largest value are set to the dtype's
    minimum value; ties with the k-th value are kept.

    Args:
        logits: unnormalized scores; filtering is applied over dim -1.
        top_k: number of candidates to keep. ``None`` or a value <= 0 disables
            filtering (previously ``None`` - the default - and ``0`` raised).
            Values larger than the vocabulary size keep everything.

    Returns:
        The filtered logits (the input tensor itself when filtering is disabled).
    """
    if top_k is None or top_k <= 0:
        return logits
    top_k = min(top_k, logits.size(-1))
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_largest, torch.finfo(logits.dtype).min)
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Pick a token per position from ``logits`` and score how confident the pick is.

    Args:
        logits: unnormalized scores with the vocabulary on the last dimension.
        temperature: > 0 samples from the (temperature-scaled) distribution;
            0 picks the argmax greedily.
        top_p: optional nucleus cutoff, applied when < 1.
        top_k: optional top-k cutoff.
        margin_confidence: if True, ``confidence`` is the gap between the two
            highest probabilities.
        neg_entropy: if True, ``confidence`` is the negative entropy of the
            distribution (takes precedence over ``margin_confidence``).

    Returns:
        ``(confidence, x0, initial_confidence)``:
        * ``x0``: the chosen token ids;
        * ``initial_confidence``: the probability of each chosen token;
        * ``confidence``: ``initial_confidence`` unless one of the flags above
          selects a different score.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)
    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except (ValueError, RuntimeError):
            # Invalid distribution (e.g. NaN/inf probabilities): fall back to greedy.
            # Narrowed from a bare `except:` so interrupts and real bugs still surface.
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)
    confidence = initial_confidence.clone()
    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # `...` indexing works for any number of leading dims (the old `[:, 0]` required 2-D).
        confidence = sorted_probs[..., 0] - sorted_probs[..., 1]
    if neg_entropy:
        epsilon = 1e-10  # avoids log(0)
        confidence = torch.sum(probs * torch.log(probs + epsilon), dim=-1)
    return confidence, x0, initial_confidence
class DreamLoRAInference:
    """Lazily loaded D2F inference engine for a LoRA-adapted LLaDA model.

    Generation runs block by block:
    * blocks of mask tokens are appended once the last block has decoded enough;
    * all active blocks are denoised in parallel each step;
    * blocks that are fully decoded (and have no active predecessor) are moved
      into the KV cache.

    The class also owns the CSS and the HTML rendering used by the Gradio
    visualization.
    """

    # Stylesheet for the Gradio UI (output panes, stats card, block/token visualization).
    CSS = """
/* Enhanced modern styling */
.main-container {
max-width: 1400px;
margin: 0 auto;
padding: 20px;
}
.output-text-container {
background: linear-gradient(135deg, #f8fafc 0%, #f1f5f9 100%);
border: 2px solid #e2e8f0;
border-radius: 12px;
padding: 20px;
margin: 15px 0;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
}
.output-text-container textarea {
background: white !important;
border: 1px solid #cbd5e1 !important;
border-radius: 8px !important;
font-family: 'Inter', 'Segoe UI', sans-serif !important;
font-size: 14px !important;
line-height: 1.6 !important;
padding: 16px !important;
box-shadow: inset 0 2px 4px 0 rgba(0, 0, 0, 0.06) !important;
}
.stats-card {
background: linear-gradient(135deg, #ecfdf5 0%, #f0fdf4 100%);
border: 2px solid #10b981;
border-radius: 12px;
padding: 20px;
margin: 15px 0;
box-shadow: 0 4px 6px -1px rgba(16, 185, 129, 0.1);
}
.stats-card h3 {
color: #065f46;
margin-top: 0;
margin-bottom: 15px;
font-weight: 600;
}
.stats-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 12px;
margin-top: 10px;
}
.stat-item {
background: white;
padding: 12px 16px;
border-radius: 8px;
border-left: 4px solid #10b981;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1);
}
.stat-label {
font-size: 12px;
color: #6b7280;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 4px;
}
.stat-value {
font-size: 18px;
font-weight: 600;
color: #065f46;
font-family: 'Monaco', 'Menlo', monospace;
}
.viz-container {
background: linear-gradient(135deg, #fefefe 0%, #f9fafb 100%);
border: 2px solid #e5e7eb;
border-radius: 12px;
padding: 20px;
margin: 15px 0;
height: 600px;
overflow-y: auto;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
position: relative;
}
.viz-header {
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
color: white;
padding: 12px 20px;
margin: -20px -20px 20px -20px;
border-radius: 12px 12px 0 0;
font-weight: 600;
font-size: 16px;
display: flex;
align-items: center;
gap: 8px;
}
.viz-header::before {
content: "🎬";
font-size: 18px;
}
.block-container {
display: inline-block;
border: 2px solid transparent;
border-radius: 10px;
padding: 8px;
margin: 6px 2px;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
background: rgba(255, 255, 255, 0.8);
backdrop-filter: blur(10px);
}
.block-updating {
border-color: #ff4500 !important;
box-shadow: 0 0 20px rgba(255, 69, 0, 0.4);
transform: scale(1.02);
background: rgba(255, 245, 238, 0.9) !important;
}
.token {
padding: 4px 8px;
margin: 2px;
border-radius: 6px;
display: inline-block;
line-height: 1.5;
font-family: 'Monaco', 'Menlo', monospace;
font-size: 13px;
font-weight: 500;
transition: all 0.2s ease;
}
.token:hover {
transform: translateY(-1px);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.15);
}
.token.prompt { background: linear-gradient(135deg, #f1f5f9 0%, #e2e8f0 100%); color: #475569; border: 1px solid #cbd5e1; }
.token.gen-0 { background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%); color: #1e40af; border: 1px solid #60a5fa; }
.token.gen-1 { background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%); color: #065f46; border: 1px solid #34d399; }
.token.gen-2 { background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); color: #92400e; border: 1px solid #fbbf24; }
.token.gen-3 { background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); color: #991b1b; border: 1px solid #f87171; }
.token.gen-4 { background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); color: #3730a3; border: 1px solid #818cf8; }
.token.gen-5 { background: linear-gradient(135deg, #f3e8ff 0%, #e9d5ff 100%); color: #6b21a8; border: 1px solid #c084fc; }
.token.mask {
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
color: #9ca3af;
border: 2px dashed #d1d5db;
animation: pulse 2s infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.6; }
}
.control-button {
background: linear-gradient(135deg, #8b5cf6 0%, #7c3aed 100%) !important;
border: none !important;
color: white !important;
padding: 12px 24px !important;
border-radius: 10px !important;
font-weight: 600 !important;
font-size: 14px !important;
box-shadow: 0 4px 6px -1px rgba(139, 92, 246, 0.3) !important;
transition: all 0.3s ease !important;
display: flex !important;
align-items: center !important;
gap: 8px !important;
margin: 10px 0 !important;
}
.control-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 8px 15px -3px rgba(139, 92, 246, 0.4) !important;
}
.control-button:active {
transform: translateY(0) !important;
}
.control-button::before {
content: "🎮";
font-size: 16px;
}
.param-card {
background: white;
border: 1px solid #e5e7eb;
border-radius: 10px;
padding: 16px;
margin: 8px 0;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1);
}
.viz-container::-webkit-scrollbar {
width: 12px;
}
.viz-container::-webkit-scrollbar-track {
background: #f1f5f9;
border-radius: 6px;
}
.viz-container::-webkit-scrollbar-thumb {
background: linear-gradient(135deg, #94a3b8 0%, #64748b 100%);
border-radius: 6px;
border: 2px solid #f1f5f9;
}
.viz-container::-webkit-scrollbar-thumb:hover {
background: linear-gradient(135deg, #64748b 0%, #475569 100%);
}
.generating-indicator {
display: inline-flex;
align-items: center;
gap: 8px;
color: #6366f1;
font-weight: 500;
}
.generating-indicator::after {
content: "";
width: 12px;
height: 12px;
border: 2px solid #6366f1;
border-top: 2px solid transparent;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
@media (max-width: 768px) {
.main-container {
padding: 10px;
}
.stats-grid {
grid-template-columns: 1fr;
}
.viz-container {
height: 400px;
}
}
"""

    def __init__(self, **kwargs):
        """Store the configuration; the model itself is loaded lazily on first use.

        All keyword arguments are kept in ``self.config`` and copied onto the
        instance as attributes. The generation loop later reads, for example,
        ``self.max_length``, ``self.temperature``, ``self.top_p``, ``self.top_k``
        and ``self.mask_token_id``.

        ``dtype`` selects the model precision:
        * ``"bfloat16"``, used only if ``torch.cuda.is_bf16_supported()``;
        * ``"float16"``;
        * anything else, including unsupported bfloat16, means float32.
        """
        print("Initializing DreamLoRAInference...")
        # Lazy loading: store the config, don't load the model yet.
        self.config = kwargs
        self.model = None
        self.tokenizer = None
        if self.config.get("dtype") == "bfloat16" and torch.cuda.is_bf16_supported():
            self.target_dtype = torch.bfloat16
        elif self.config.get("dtype") == "float16":
            self.target_dtype = torch.float16
        else:
            self.target_dtype = torch.float32
        # Expose every config entry as an attribute (self.max_length, self.temperature, ...).
        for key, value in kwargs.items():
            setattr(self, key, value)
        print("DreamLoRAInference configured. Model will be loaded on first use.")

    def _ensure_model_loaded(self):
        """Load model and tokenizer if they haven't been loaded yet."""
        if self.model is None:
            print("Loading model and tokenizer for the first time...")
            self._setup_model(self.config["pretrained_path"], self.config["lora_path"])
            print("Model and tokenizer setup complete.")

    def _setup_model(self, pretrained_path, lora_path):
        """Load the base LLaDA model, wrap it with the LoRA adapter, and load the tokenizer.

        Falls back to the EOS token as pad token when the tokenizer defines none.
        """
        config = LLaDAConfig.from_pretrained(pretrained_path)
        self.model = LLaDAModelLM.from_pretrained(
            pretrained_path,
            config=config,
            torch_dtype=self.target_dtype,
            device_map="auto"  # Use device_map for auto hardware assignment
        ).eval()
        self.model = PeftModel.from_pretrained(self.model, lora_path)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _apply_chat_template(self, prompt):
        """Wrap ``prompt`` as a single user turn and render it with the tokenizer's chat template."""
        chat_history = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)

    def _update_block_completion_states(self, block_states, decoded_token_threshold):
        """Propagate completion flags from each block to the one after it.

        Once block ``k`` has decoded at least ``decoded_token_threshold`` of its
        masked positions, block ``k + 1`` is marked ``is_complete``. The
        generation loop forces a complete block to commit at least one token per
        step, even when no prediction clears the skip threshold.
        """
        for block_id in sorted(block_states.keys()):
            if 'total_masks' in block_states[block_id] and block_states[block_id]['total_masks'] > 0:
                decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
                decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
                if decode_ratio >= decoded_token_threshold:
                    if (next_block_id := block_id + 1) in block_states:
                        block_states[next_block_id]['is_complete'] = True

    def _render_visualization_html(self, step: int, x_t: torch.Tensor, block_states: Dict, cache_length: int, updated_block_ids: Set[int]) -> str:
        """Render one visualization frame of all generation blocks as HTML.

        Each generation block (block 0, the prompt, is skipped) becomes a
        container of token spans:
        * masked positions are shown as a placeholder;
        * decoded tokens are HTML-escaped and colored by block index;
        * blocks listed in ``updated_block_ids`` are highlighted.

        ``step`` and ``cache_length`` are currently unused; they are kept for
        interface compatibility.
        """
        timestamp = int(time.time() * 1000)
        html_parts = ['<div class="viz-header">Slow-Motion Generation Process</div>']
        for block_id in sorted(k for k in block_states.keys() if k > 0):
            state = block_states[block_id]
            container_classes = ["block-container"]
            if block_id in updated_block_ids:
                container_classes.append("block-updating")
            html_parts.append(f'<div class="{" ".join(container_classes)}" id="block-{block_id}-{timestamp}">')
            # One host transfer per block instead of a device sync per token via `.item()`.
            block_token_ids = x_t[0, state['start_pos']:state['end_pos']].tolist()
            for token_id_int in block_token_ids:
                token_classes = ["token"]
                if token_id_int == self.mask_token_id:
                    token_str = '░'
                    token_classes.append("mask")
                else:
                    token_str = self.tokenizer.decode([token_id_int], skip_special_tokens=False)
                    # Escape &, < and > so decoded text renders literally instead of
                    # being parsed as markup (the previous replace() chain was a no-op).
                    token_str = html.escape(token_str, quote=False)
                    token_classes.append(f"gen-{(block_id - 1) % 6}")
                html_parts.append(f'<span class="{" ".join(token_classes)}">{token_str}</span>')
            html_parts.append('</div>')
        html_parts.append(f'<div class="scroll-anchor" id="viz-anchor-{timestamp}"></div>')
        # Inline script that scrolls the visualization container to the bottom.
        complete_html = f"""
<div class="viz-content" id="viz-content-{timestamp}">
{''.join(html_parts)}
</div>
<script>
(function() {{
const container = document.querySelector('.viz-container');
if (container) {{ container.scrollTop = container.scrollHeight; }}
}})();
</script>
"""
        return complete_html

    @spaces.GPU
    @torch.inference_mode()
    def stream_and_capture_for_gradio(
        self,
        prompt_text: str,
        max_new_tokens: int,
        block_size: int,
        block_add_threshold: float,
        decoded_token_threshold: float,
        skip_threshold: float
    ) -> Iterator[Tuple[str, List[str], str, bool]]:
        """Run D2F block-wise generation and stream intermediate results to Gradio.

        Args:
            prompt_text: raw user prompt; wrapped with the tokenizer's chat template.
            max_new_tokens: limits the number of generation blocks to
                ``max_new_tokens // block_size``.
            block_size: number of mask tokens appended per new block.
            block_add_threshold: decoded fraction of the last block required
                before another block is appended.
            decoded_token_threshold: decoded fraction of a block required before
                the next block is marked complete.
            skip_threshold: predictions whose confidence exceeds this value are
                committed in the current step.

        Yields:
            ``(text, frames, status_html, done)`` tuples.
            * ``frames`` is the growing list of captured visualization HTML
              snapshots.
            * The final tuple carries the EOS-truncated text, a statistics card,
              and ``done=True``.
        """
        self._ensure_model_loaded()
        start_time = time.time()
        captured_frames: List[str] = []
        input_ids = self.tokenizer(self._apply_chat_template(prompt_text), return_tensors="pt").input_ids.to(self.model.device)
        prompt_length = input_ids.shape[1]
        # Block-causal mask over the whole sequence; each step slices its part out.
        full_attention_mask = create_full_block_attention_mask(prompt_length, self.max_length, block_size, self.model.device, self.target_dtype)
        x_t = input_ids
        # Block 0 is the prompt: nothing to decode, waiting to be written into the KV cache.
        block_states = {0: {'start_pos': 0, 'end_pos': prompt_length, 'mask_count': 0, 'total_masks': prompt_length, 'state': 'to_cache', 'is_complete': True}}
        past_key_values, current_blocks, step, eos_detected, cache_length = None, 0, 0, False, 0
        initial_viz_html = self._render_visualization_html(0, x_t, block_states, 0, set())
        captured_frames.append(initial_viz_html)
        yield "", captured_frames, '<div class="generating-indicator">Initializing generation process...</div>', False
        while True:
            step += 1
            updated_block_ids: Set[int] = set()

            # 1) Append a new block of masks once the last block has decoded enough.
            if len(block_states) - 1 < (max_new_tokens // block_size) and not eos_detected:
                last_block_id = max(block_states.keys())
                last_state = block_states[last_block_id]
                if last_state['total_masks'] > 0:
                    progress = (last_state['total_masks'] - last_state['mask_count']) / last_state['total_masks']
                else:
                    progress = 1.0
                if progress >= block_add_threshold:
                    new_block_id = last_block_id + 1
                    new_start_pos = x_t.shape[1]
                    if new_start_pos + block_size <= self.max_length:
                        x_t = torch.cat([x_t, torch.full((1, block_size), self.mask_token_id, device=self.model.device, dtype=torch.long)], dim=1)
                        block_states[new_block_id] = {'start_pos': new_start_pos, 'end_pos': new_start_pos + block_size, 'mask_count': block_size, 'total_masks': block_size, 'state': 'active', 'is_complete': False}
                        current_blocks += 1
            self._update_block_completion_states(block_states, decoded_token_threshold)
            if (x_t == self.mask_token_id).sum() == 0 and current_blocks == 0:
                break

            # 2) Choose the span fed to the model. If blocks are waiting to be
            #    cached, start at the first of them so the cache gets extended.
            #    Otherwise start at the first active, uncached block.
            blocks_to_cache = [bid for bid, state in block_states.items() if state['state'] == 'to_cache']
            update_kvcache = 0
            if blocks_to_cache:
                start_pos = block_states[min(blocks_to_cache)]['start_pos']
                end_pos = block_states[max(blocks_to_cache)]['end_pos']
                update_kvcache = end_pos - start_pos
                input_seq, process_start_pos = x_t[:, start_pos:], start_pos
            else:
                active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active' and state['start_pos'] >= cache_length]
                if not active_blocks:
                    break
                start_pos = min(block_states[bid]['start_pos'] for bid in active_blocks)
                input_seq, process_start_pos = x_t[:, start_pos:], start_pos
            if input_seq.shape[1] == 0:
                break
            attention_mask = extract_attention_mask(full_attention_mask, process_start_pos, input_seq.shape[1], cache_length)
            outputs = self.model(input_seq, attention_bias=attention_mask, past_key_values=past_key_values, use_cache=True, update_kvcache=update_kvcache + cache_length)
            if update_kvcache > 0:
                past_key_values = outputs.past_key_values
                for bid in blocks_to_cache:
                    block_states[bid]['state'] = 'in_cache'

            # 3) Denoise every active block in parallel: commit confident predictions.
            blocks_to_deactivate = []
            for block_id, state in block_states.items():
                if state['state'] != 'active':
                    continue
                block_mask_locs = (x_t[0, state['start_pos']:state['end_pos']] == self.mask_token_id).nonzero().squeeze(-1)
                if block_mask_locs.numel() == 0:
                    blocks_to_deactivate.append(block_id)
                    continue
                logit_offset = state['start_pos'] - process_start_pos
                block_mask_logits = outputs.logits[:, logit_offset + block_mask_locs, :]
                _, x0, initial_confidence = sample_tokens(block_mask_logits.squeeze(0), self.temperature, self.top_p, self.top_k)
                all_indices = (initial_confidence > skip_threshold).nonzero().squeeze(-1)
                # A complete block must make progress: commit its single most confident prediction.
                if state['is_complete'] and all_indices.numel() == 0 and block_mask_logits.numel() > 0:
                    all_indices = torch.tensor([torch.argmax(initial_confidence)], device=self.model.device)
                if all_indices.numel() > 0:
                    updated_block_ids.add(block_id)
                    positions_to_update = state['start_pos'] + block_mask_locs[all_indices]
                    x_t[0, positions_to_update] = x0[all_indices]
                    state['mask_count'] -= all_indices.numel()
                    if self.tokenizer.eos_token_id in x0[all_indices]:
                        eos_detected = True
                if state['mask_count'] == 0:
                    blocks_to_deactivate.append(block_id)

            # 4) A finished block is queued for caching only when no earlier block is still active.
            for bid in blocks_to_deactivate:
                if block_states[bid]['state'] == 'active' and all(block_states.get(i, {}).get('state') != 'active' for i in range(bid)):
                    block_states[bid]['state'] = 'to_cache'
                    current_blocks -= 1
            if update_kvcache > 0:
                cache_length += update_kvcache

            generated_ids = x_t[0, prompt_length:]
            valid_ids = generated_ids[generated_ids != self.mask_token_id]
            live_text = self.tokenizer.decode(valid_ids, skip_special_tokens=True)
            current_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, updated_block_ids)
            captured_frames.append(current_viz_html)
            yield live_text, captured_frames, f'<div class="generating-indicator">Generating... Step {step}</div>', False

        total_time = time.time() - start_time
        final_generated_ids = x_t[0, prompt_length:]
        # Truncate after the first EOS (inclusive); tokens after it may have been decoded in parallel.
        eos_positions = (final_generated_ids == self.tokenizer.eos_token_id).nonzero()
        if eos_positions.numel() > 0:
            final_generated_ids = final_generated_ids[:eos_positions[0, 0] + 1]
        final_text = self.tokenizer.decode(final_generated_ids, skip_special_tokens=True)
        final_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, set())
        captured_frames.append(final_viz_html)
        tokens_incl_eos = len(final_generated_ids)
        tokens_excl_eos = len(final_generated_ids[final_generated_ids != self.tokenizer.eos_token_id])
        # Guard against a zero elapsed time (coarse clocks) to avoid ZeroDivisionError.
        tokens_per_second = tokens_incl_eos / total_time if total_time > 0 else 0.0
        stats_html = f"""
<div class="stats-card">
<h3>✅ Generation Complete!</h3>
<div class="stats-grid">
<div class="stat-item">
<div class="stat-label">Total Time</div>
<div class="stat-value">{total_time:.2f}s</div>
</div>
<div class="stat-item">
<div class="stat-label">Tokens (incl. EOS)</div>
<div class="stat-value">{tokens_incl_eos}</div>
</div>
<div class="stat-item">
<div class="stat-label">Tokens (excl. EOS)</div>
<div class="stat-value">{tokens_excl_eos}</div>
</div>
<div class="stat-item">
<div class="stat-label">Tokens/Second</div>
<div class="stat-value">{tokens_per_second:.1f}</div>
</div>
</div>
</div>
"""
        yield final_text, captured_frames, stats_html, True
# --- Gradio UI and Event Handlers ---
# Script entry point: builds the inference engine and the Gradio Blocks UI, wires
# the generate / slow-motion replay events, and launches the app with queuing.
if __name__ == "__main__":
    # Use Hugging Face Hub model IDs
    config = {
        "pretrained_path": "GSAI-ML/LLaDA-8B-Instruct",
        "lora_path": "SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora",
        # max_length bounds prompt + generated tokens: new blocks are only appended while they fit.
        "dtype": "bfloat16", "max_length": 4096,
        # temperature 0.0 selects greedy decoding in sample_tokens; mask_token_id marks undecoded positions.
        "temperature": 0.0, "top_p": None, "top_k": None, "mask_token_id": 126336,
        # NOTE(review): "sampling_strategy" is stored on the engine but not read by any code visible here.
        "sampling_strategy": "default",
    }
    set_seed(42)
    # Construction is cheap: the model is loaded lazily on the first generation request.
    inference_engine = DreamLoRAInference(**config)

    def animate_visualization(html_frames_list: List[str], delay: float) -> Iterator[str]:
        """Replay captured visualization frames, sleeping ``delay`` seconds after each one."""
        if not html_frames_list:
            yield '<div class="viz-header">No visualization data captured</div>'
            return
        for frame in html_frames_list:
            yield frame
            time.sleep(delay)

    # Simplified auto-scroll JS from your original script
    # NOTE(review): <script> tags inside gr.HTML may not be executed by Gradio, and
    # DOMContentLoaded may already have fired when this is inserted - confirm this
    # actually runs (gr.Blocks(head=...) or js=... may be required).
    auto_scroll_js = """
<script>
function setupAutoScroll(containerSelector, contentSelector) {
const container = document.querySelector(containerSelector);
if (!container) return;
const observer = new MutationObserver(() => {
container.scrollTop = container.scrollHeight;
});
observer.observe(container, {
childList: true,
subtree: true
});
}
document.addEventListener('DOMContentLoaded', () => {
// Use a timeout to ensure Gradio elements are rendered
setTimeout(() => {
setupAutoScroll('#live-text-output', 'textarea');
setupAutoScroll('.viz-container', '.viz-content');
}, 1500);
});
</script>
"""
    with gr.Blocks(css=DreamLoRAInference.CSS, theme=gr.themes.Soft(), title="D2F-LLaDA Visualization") as demo:
        # Per-session state: captured visualization frames and whether generation has finished.
        html_frames_state = gr.State([])
        generation_complete_state = gr.State(False)
        gr.HTML(auto_scroll_js) # Keep the JS injection
        # Page layout: header/description, input + settings column, output column, replay panel, examples.
        with gr.Column(elem_classes=["main-container"]):
            gr.Markdown("# ✨ D2F: Faster-than-AR Inference for Diffusion LLMs")
            gr.Markdown(
                """
[GitHub](https://github.com/zhijie-group/Discrete-Diffusion-Forcing) | [📜 Paper](https://arxiv.org/abs/2508.09192) | [🌐 Blog Post](https://zhijie-group.github.io/Discrete-Diffusion-Forcing/) | [🤗 D2F-LLaDA LoRA](https://huggingface.co/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora) | [🤗 D2F-Dream LoRA](https://huggingface.co/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora)
"""
            )
            gr.Markdown(
                """
This demo showcases **Discrete Diffusion Forcing (D2F)**, a novel framework that enables Diffusion Language Models (dLLMs) to achieve faster-than-autoregressive inference speeds for the first time. D2F creates an AR-Diffusion hybrid paradigm that combines the efficiency of KV Caching with inter-block parallel decoding.
The model powering this demo is **LLaDA-Instruct-8B**, fine-tuned with our D2F method. Watch its unique block-wise generation in real-time, then replay the process in slow motion to see how it works!
"""
            )
            with gr.Row():
                # Left column: prompt and generation parameters.
                with gr.Column(scale=2):
                    prompt_input = gr.Textbox(
                        label="🤔 Enter your question",
                        placeholder="Ask me anything! Try: 'Explain quantum physics' or 'Write a story about...'",
                        lines=4,
                        elem_classes=["param-card"]
                    )
                    with gr.Accordion("⚙️ Advanced Settings", open=False):
                        with gr.Row():
                            max_new_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=2048, step=64, label="Max Tokens", info="Maximum number of tokens to generate")
                            block_size_slider = gr.Slider(minimum=16, maximum=128, value=32, step=16, label="Block Size", info="Size of each generation block")
                        with gr.Row():
                            block_add_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Block Add Threshold", info="When to add new blocks")
                            decoded_token_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Completion Threshold", info="Block completion criteria")
                        with gr.Row():
                            skip_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Skip Threshold", info="Token selection threshold")
                            # The value is a per-frame sleep in seconds, so larger values replay more slowly.
                            delay_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.15, step=0.05, label="Playback Speed", info="Slow-motion playback delay (seconds)")
                    generate_button = gr.Button("🚀 Generate Text", variant="primary", size="lg")
                # Right column: streamed text and final statistics.
                with gr.Column(scale=3):
                    with gr.Group(elem_classes=["output-text-container"]):
                        gr.Markdown("### 📝 Generated Text (Real-time)")
                        live_text_output = gr.Textbox(label="", interactive=False, lines=15, show_label=False, placeholder="Generated text will appear here as the AI thinks...", elem_id="live-text-output")
                    stats_output = gr.HTML(elem_id="stats-output")
            # Slow-motion replay panel: hidden until a generation has completed.
            with gr.Row():
                with gr.Column():
                    slowmo_button = gr.Button("🎬 Watch Slow-Motion Generation Process", variant="secondary", size="lg", elem_classes=["control-button"], visible=False, interactive=False)
                    with gr.Group(elem_classes=["viz-container"], visible=False) as viz_group:
                        visualization_output = gr.HTML(label="")
            # Example rows, in the column order of `inputs` below (prompt, max tokens,
            # block size, block add / completion / skip thresholds, playback delay).
            gr.Examples(
                examples=[
                    ["A circular swimming pool has a diameter of 8 meters. Calculate the pool's circumference and area. First, explain the relationship between diameter, radius, circumference, and area of a circle, including the role of π in these formulas. Then perform the calculations using π ≈ 3.14159. Next, estimate how much water (in cubic meters) would be needed to fill this pool if it has a uniform depth of 1.5 meters. Finally, calculate how much it would cost to fill this pool if water costs $2.50 per cubic meter. Show all steps and include appropriate units in your answer.", 2048, 32, 0.1, 0.5, 0.9, 0.1],
                    ["A movie theater offers a loyalty card that costs $15 and gives a 15% discount on all tickets. If a regular movie ticket costs $10, how many tickets would you need to buy to make the loyalty card worthwhile? First, explain the concept of a break-even point. Then set up an equation to find when the total cost with the card equals the total cost without the card. Solve this equation step by step, showing all your work. Finally, interpret your answer in the context of the problem.", 2048, 32, 0.1, 0.5, 0.9, 0.1],
                    ["Solve the equation x² - 6x + 8 = 0. First, explain what a quadratic equation is and why it can have up to two solutions. Then solve this equation using three different methods: factoring, completing the square, and the quadratic formula. For each method, explain the mathematical reasoning behind it, show all steps in detail, and discuss when this particular method is most useful. Finally, verify your solutions by substituting them back into the original equation.", 2048, 32, 0.1, 0.55, 0.9, 0.1],
                ],
                inputs=[prompt_input, max_new_tokens_slider, block_size_slider, block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider, delay_slider],
                label="💡 Try these examples"
            )
        # Event handling is now identical to your original, correct script
        def update_slowmo_button_visibility(is_complete):
            """Show and enable the replay button only once generation has completed."""
            return gr.update(visible=is_complete, interactive=is_complete)

        def show_visualization():
            """Reveal the visualization panel."""
            return gr.update(visible=True)

        # Order must match the positional parameters of stream_and_capture_for_gradio.
        inputs_list = [
            prompt_input, max_new_tokens_slider, block_size_slider,
            block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
        ]
        # Generate: reset outputs and hide replay UI -> stream generation (text, frames,
        # status, done flag) -> reveal the replay button once the done flag is set.
        generation_event = generate_button.click(
            fn=lambda: [gr.update(visible=False, interactive=False), gr.update(visible=False), gr.update(value=None), gr.update(value="")],
            outputs=[slowmo_button, viz_group, stats_output, live_text_output]
        ).then(
            fn=inference_engine.stream_and_capture_for_gradio,
            inputs=inputs_list,
            outputs=[live_text_output, html_frames_state, stats_output, generation_complete_state]
        ).then(
            fn=update_slowmo_button_visibility,
            inputs=[generation_complete_state],
            outputs=[slowmo_button]
        )
        # Replay: show the panel, then stream the captured frames with the chosen delay.
        slowmo_event = slowmo_button.click(
            fn=show_visualization,
            outputs=[viz_group]
        ).then(
            fn=animate_visualization,
            inputs=[html_frames_state, delay_slider],
            outputs=[visualization_output]
        )
    # queue() is needed so generator handlers can stream updates to the browser.
    demo.queue().launch()