# app.py
import torch
import torch.nn.functional as F
import torch.distributions as dists
import transformers
from transformers import AutoTokenizer
from peft import PeftModel, PeftConfig
import numpy as np
import random
import time
import os
from typing import List, Dict, Optional, Tuple, Iterator, Set
import gradio as gr
import spaces  # Import the spaces module (ZeroGPU support)
# Suppress some Hugging Face warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Import necessary model classes from the local directory.
# Make sure the 'model_cache' directory is in your Hugging Face Space repository.
from model_cache.llada.modeling_llada import LLaDAModelLM
from model_cache.llada.configuration_llada import LLaDAConfig

# --- Helper Functions ---
def set_seed(seed):
    torch.manual_seed(seed); random.seed(seed); np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
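
# Builds a (1, 1, max_length, max_length) additive attention bias for block-wise
# generation: the prompt attends to itself, and every generated block attends to
# the prompt, to all previous blocks, and to itself (0 = visible, -inf = masked).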
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    if dtype is None: dtype = torch.bfloat16
    attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
    attention_mask[:, :, :prompt_length, :prompt_length] = 0
    remaining_length = max_length - prompt_length
    num_blocks = (remaining_length + block_size - 1) // block_size
    for b in range(num_blocks):
        block_start = prompt_length + b * block_size; block_end = min(prompt_length + (b + 1) * block_size, max_length)
        attention_mask[:, :, block_start:block_end, :prompt_length] = 0
        for prev_b in range(b):
            prev_start = prompt_length + prev_b * block_size; prev_end = min(prompt_length + (prev_b + 1) * block_size, max_length)
            attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0
        attention_mask[:, :, block_start:block_end, block_start:block_end] = 0
    return attention_mask
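
# Slices the precomputed full mask down to the rows for the current input window
# and the columns for (cached tokens + current window), matching the KV-cache layout.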
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    end_pos = start_pos + input_length; total_length = cache_length + input_length
    extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf, device=full_mask.device, dtype=full_mask.dtype)
    extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
    extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
    return extracted_mask
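
# Standard nucleus (top-p) and top-k filtering: logits outside the kept set are
# pushed to the dtype minimum so they receive zero probability after softmax.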
def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits

def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits
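
# Samples one token per masked position and returns a per-position confidence score
# (the sampled token's probability, a top-1/top-2 margin, or negative entropy),
# which the decoding loop compares against skip_threshold.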
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    if temperature > 0: logits = logits / temperature
    if top_p is not None and top_p < 1: logits = top_p_logits(logits, top_p)
    if top_k is not None: logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)
    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:  # fall back to greedy selection if sampling fails
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)
    confidence = initial_confidence.clone()
    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
    if neg_entropy:
        epsilon = 1e-10
        confidence = torch.sum(probs * torch.log(probs + epsilon), dim=-1)
    return confidence, x0, initial_confidence

class DreamLoRAInference:
    # CSS styling for the Gradio interface
CSS = """ | |
/* Enhanced modern styling */ | |
.main-container { | |
max-width: 1400px; | |
margin: 0 auto; | |
padding: 20px; | |
} | |
.output-text-container { | |
background: linear-gradient(135deg, #f8fafc 0%, #f1f5f9 100%); | |
border: 2px solid #e2e8f0; | |
border-radius: 12px; | |
padding: 20px; | |
margin: 15px 0; | |
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); | |
} | |
.output-text-container textarea { | |
background: white !important; | |
border: 1px solid #cbd5e1 !important; | |
border-radius: 8px !important; | |
font-family: 'Inter', 'Segoe UI', sans-serif !important; | |
font-size: 14px !important; | |
line-height: 1.6 !important; | |
padding: 16px !important; | |
box-shadow: inset 0 2px 4px 0 rgba(0, 0, 0, 0.06) !important; | |
} | |
.stats-card { | |
background: linear-gradient(135deg, #ecfdf5 0%, #f0fdf4 100%); | |
border: 2px solid #10b981; | |
border-radius: 12px; | |
padding: 20px; | |
margin: 15px 0; | |
box-shadow: 0 4px 6px -1px rgba(16, 185, 129, 0.1); | |
} | |
.stats-card h3 { | |
color: #065f46; | |
margin-top: 0; | |
margin-bottom: 15px; | |
font-weight: 600; | |
} | |
.stats-grid { | |
display: grid; | |
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); | |
gap: 12px; | |
margin-top: 10px; | |
} | |
.stat-item { | |
background: white; | |
padding: 12px 16px; | |
border-radius: 8px; | |
border-left: 4px solid #10b981; | |
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1); | |
} | |
.stat-label { | |
font-size: 12px; | |
color: #6b7280; | |
text-transform: uppercase; | |
letter-spacing: 0.5px; | |
margin-bottom: 4px; | |
} | |
.stat-value { | |
font-size: 18px; | |
font-weight: 600; | |
color: #065f46; | |
font-family: 'Monaco', 'Menlo', monospace; | |
} | |
.viz-container { | |
background: linear-gradient(135deg, #fefefe 0%, #f9fafb 100%); | |
border: 2px solid #e5e7eb; | |
border-radius: 12px; | |
padding: 20px; | |
margin: 15px 0; | |
height: 600px; | |
overflow-y: auto; | |
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); | |
position: relative; | |
} | |
.viz-header { | |
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%); | |
color: white; | |
padding: 12px 20px; | |
margin: -20px -20px 20px -20px; | |
border-radius: 12px 12px 0 0; | |
font-weight: 600; | |
font-size: 16px; | |
display: flex; | |
align-items: center; | |
gap: 8px; | |
} | |
.viz-header::before { | |
content: "🎬"; | |
font-size: 18px; | |
} | |
.block-container { | |
display: inline-block; | |
border: 2px solid transparent; | |
border-radius: 10px; | |
padding: 8px; | |
margin: 6px 2px; | |
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); | |
background: rgba(255, 255, 255, 0.8); | |
backdrop-filter: blur(10px); | |
} | |
.block-updating { | |
border-color: #ff4500 !important; | |
box-shadow: 0 0 20px rgba(255, 69, 0, 0.4); | |
transform: scale(1.02); | |
background: rgba(255, 245, 238, 0.9) !important; | |
} | |
.token { | |
padding: 4px 8px; | |
margin: 2px; | |
border-radius: 6px; | |
display: inline-block; | |
line-height: 1.5; | |
font-family: 'Monaco', 'Menlo', monospace; | |
font-size: 13px; | |
font-weight: 500; | |
transition: all 0.2s ease; | |
} | |
.token:hover { | |
transform: translateY(-1px); | |
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.15); | |
} | |
.token.prompt { background: linear-gradient(135deg, #f1f5f9 0%, #e2e8f0 100%); color: #475569; border: 1px solid #cbd5e1; } | |
.token.gen-0 { background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%); color: #1e40af; border: 1px solid #60a5fa; } | |
.token.gen-1 { background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%); color: #065f46; border: 1px solid #34d399; } | |
.token.gen-2 { background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); color: #92400e; border: 1px solid #fbbf24; } | |
.token.gen-3 { background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%); color: #991b1b; border: 1px solid #f87171; } | |
.token.gen-4 { background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); color: #3730a3; border: 1px solid #818cf8; } | |
.token.gen-5 { background: linear-gradient(135deg, #f3e8ff 0%, #e9d5ff 100%); color: #6b21a8; border: 1px solid #c084fc; } | |
.token.mask { | |
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%); | |
color: #9ca3af; | |
border: 2px dashed #d1d5db; | |
animation: pulse 2s infinite; | |
} | |
@keyframes pulse { | |
0%, 100% { opacity: 1; } | |
50% { opacity: 0.6; } | |
} | |
.control-button { | |
background: linear-gradient(135deg, #8b5cf6 0%, #7c3aed 100%) !important; | |
border: none !important; | |
color: white !important; | |
padding: 12px 24px !important; | |
border-radius: 10px !important; | |
font-weight: 600 !important; | |
font-size: 14px !important; | |
box-shadow: 0 4px 6px -1px rgba(139, 92, 246, 0.3) !important; | |
transition: all 0.3s ease !important; | |
display: flex !important; | |
align-items: center !important; | |
gap: 8px !important; | |
margin: 10px 0 !important; | |
} | |
.control-button:hover { | |
transform: translateY(-2px) !important; | |
box-shadow: 0 8px 15px -3px rgba(139, 92, 246, 0.4) !important; | |
} | |
.control-button:active { | |
transform: translateY(0) !important; | |
} | |
.control-button::before { | |
content: "🎮"; | |
font-size: 16px; | |
} | |
.param-card { | |
background: white; | |
border: 1px solid #e5e7eb; | |
border-radius: 10px; | |
padding: 16px; | |
margin: 8px 0; | |
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1); | |
} | |
.viz-container::-webkit-scrollbar { | |
width: 12px; | |
} | |
.viz-container::-webkit-scrollbar-track { | |
background: #f1f5f9; | |
border-radius: 6px; | |
} | |
.viz-container::-webkit-scrollbar-thumb { | |
background: linear-gradient(135deg, #94a3b8 0%, #64748b 100%); | |
border-radius: 6px; | |
border: 2px solid #f1f5f9; | |
} | |
.viz-container::-webkit-scrollbar-thumb:hover { | |
background: linear-gradient(135deg, #64748b 0%, #475569 100%); | |
} | |
.generating-indicator { | |
display: inline-flex; | |
align-items: center; | |
gap: 8px; | |
color: #6366f1; | |
font-weight: 500; | |
} | |
.generating-indicator::after { | |
content: ""; | |
width: 12px; | |
height: 12px; | |
border: 2px solid #6366f1; | |
border-top: 2px solid transparent; | |
border-radius: 50%; | |
animation: spin 1s linear infinite; | |
} | |
@keyframes spin { | |
0% { transform: rotate(0deg); } | |
100% { transform: rotate(360deg); } | |
} | |
@media (max-width: 768px) { | |
.main-container { | |
padding: 10px; | |
} | |
.stats-grid { | |
grid-template-columns: 1fr; | |
} | |
.viz-container { | |
height: 400px; | |
} | |
} | |
""" | |
    def __init__(self, **kwargs):
        print("Initializing DreamLoRAInference...")
        # Lazy loading: store config, don't load model yet
        self.config = kwargs
        self.model = None
        self.tokenizer = None
        if self.config.get("dtype") == "bfloat16" and torch.cuda.is_bf16_supported():
            self.target_dtype = torch.bfloat16
        elif self.config.get("dtype") == "float16":
            self.target_dtype = torch.float16
        else:
            self.target_dtype = torch.float32
        # Set attributes from config
        for key, value in kwargs.items():
            setattr(self, key, value)
        print("DreamLoRAInference configured. Model will be loaded on first use.")

    def _ensure_model_loaded(self):
        """Load model and tokenizer if they haven't been loaded yet."""
        if self.model is None:
            print("Loading model and tokenizer for the first time...")
            self._setup_model(self.config["pretrained_path"], self.config["lora_path"])
            print("Model and tokenizer setup complete.")

    def _setup_model(self, pretrained_path, lora_path):
        config = LLaDAConfig.from_pretrained(pretrained_path)
        self.model = LLaDAModelLM.from_pretrained(
            pretrained_path,
            config=config,
            torch_dtype=self.target_dtype,
            device_map="auto"  # Use device_map for automatic hardware assignment
        ).eval()
        self.model = PeftModel.from_pretrained(self.model, lora_path)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _apply_chat_template(self, prompt):
        chat_history = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
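
    # Marks the next block as 'complete' (eligible for forced decoding) once the ratio
    # of decoded tokens in the current block reaches decoded_token_threshold.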
    def _update_block_completion_states(self, block_states, decoded_token_threshold):
        for block_id in sorted(block_states.keys()):
            if 'total_masks' in block_states[block_id] and block_states[block_id]['total_masks'] > 0:
                decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
                decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
                if decode_ratio >= decoded_token_threshold:
                    if (next_block_id := block_id + 1) in block_states:
                        block_states[next_block_id]['is_complete'] = True
    def _render_visualization_html(self, step: int, x_t: torch.Tensor, block_states: Dict, cache_length: int, updated_block_ids: Set[int]) -> str:
        # Renders the current decoding state of every block as HTML; token text is escaped before display.
        timestamp = int(time.time() * 1000)
        html_parts = []
        html_parts.append('<div class="viz-header">Slow-Motion Generation Process</div>')
        for block_id in sorted(k for k in block_states.keys() if k > 0):
            state = block_states[block_id]
            container_classes = ["block-container"]
            if block_id in updated_block_ids: container_classes.append("block-updating")
            html_parts.append(f'<div class="{" ".join(container_classes)}" id="block-{block_id}-{timestamp}">')
            block_tokens = x_t[0, state['start_pos']:state['end_pos']]
            for token_id in block_tokens:
                token_id_int = token_id.item()
                token_classes = ["token"]
                if token_id_int == self.mask_token_id:
                    token_str = '░'; token_classes.append("mask")
                else:
                    token_str = self.tokenizer.decode([token_id_int], skip_special_tokens=False)
                    token_str = token_str.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
token_classes.append(f"gen-{(block_id - 1) % 6}") | |
html_parts.append(f'<span class="{" ".join(token_classes)}">{token_str}</span>') | |
html_parts.append('</div>') | |
html_parts.append(f'<div class="scroll-anchor" id="viz-anchor-{timestamp}"></div>') | |
# Script part from original for scrolling | |
complete_html = f""" | |
<div class="viz-content" id="viz-content-{timestamp}"> | |
{''.join(html_parts)} | |
</div> | |
<script> | |
(function() {{ | |
const container = document.querySelector('.viz-container'); | |
if (container) {{ container.scrollTop = container.scrollHeight; }} | |
}})(); | |
</script> | |
""" | |
return complete_html | |
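
    # Assumption for ZeroGPU ("Running on Zero") Spaces: the GPU-bound entry point is
    # normally decorated with @spaces.GPU so a GPU is attached for the duration of the
    # call. Remove the decorator if GPU allocation is handled elsewhere.
    @spaces.GPU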
    def stream_and_capture_for_gradio(
        self,
        prompt_text: str,
        max_new_tokens: int,
        block_size: int,
        block_add_threshold: float,
        decoded_token_threshold: float,
        skip_threshold: float
    ) -> Iterator[Tuple[str, List[str], str, bool]]:
        # Core D2F generation loop: decodes blocks in parallel while streaming UI updates.
        self._ensure_model_loaded()
        start_time = time.time()
        captured_frames: List[str] = []
        input_ids = self.tokenizer(self._apply_chat_template(prompt_text), return_tensors="pt").input_ids.to(self.model.device)
        prompt_length = input_ids.shape[1]
        full_attention_mask = create_full_block_attention_mask(prompt_length, self.max_length, block_size, self.model.device, self.target_dtype)
        x_t = input_ids
        block_states = {0: {'start_pos': 0, 'end_pos': prompt_length, 'mask_count': 0, 'total_masks': prompt_length, 'state': 'to_cache', 'is_complete': True}}
        past_key_values, current_blocks, step, eos_detected, cache_length = None, 0, 0, False, 0
        initial_viz_html = self._render_visualization_html(0, x_t, block_states, 0, set())
        captured_frames.append(initial_viz_html)
        yield "", captured_frames, '<div class="generating-indicator">Initializing generation process...</div>', False
        while True:
            step += 1
            updated_block_ids: Set[int] = set()
            if len(block_states) - 1 < (max_new_tokens // block_size) and not eos_detected:
                last_block_id = max(block_states.keys())
                progress = (block_states[last_block_id]['total_masks'] - block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks'] if block_states[last_block_id]['total_masks'] > 0 else 1.0
                if progress >= block_add_threshold:
                    new_block_id = last_block_id + 1; new_start_pos = x_t.shape[1]
                    if new_start_pos + block_size <= self.max_length:
                        x_t = torch.cat([x_t, torch.full((1, block_size), self.mask_token_id, device=self.model.device, dtype=torch.long)], dim=1)
                        block_states[new_block_id] = {'start_pos': new_start_pos, 'end_pos': new_start_pos + block_size, 'mask_count': block_size, 'total_masks': block_size, 'state': 'active', 'is_complete': False}
                        current_blocks += 1
            self._update_block_completion_states(block_states, decoded_token_threshold)
            if (x_t == self.mask_token_id).sum() == 0 and current_blocks == 0: break
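
            # Fully decoded leading blocks are written into the KV cache on this pass;
            # otherwise only the still-active (uncached) blocks are re-processed.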
            blocks_to_cache = [bid for bid, state in block_states.items() if state['state'] == 'to_cache']
            update_kvcache = 0
            if blocks_to_cache:
                start_pos, end_pos = block_states[min(blocks_to_cache)]['start_pos'], block_states[max(blocks_to_cache)]['end_pos']
                update_kvcache = end_pos - start_pos; input_seq, process_start_pos = x_t[:, start_pos:], start_pos
            else:
                active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active' and state['start_pos'] >= cache_length]
                if not active_blocks: break
                start_pos = min(block_states[bid]['start_pos'] for bid in active_blocks); input_seq, process_start_pos = x_t[:, start_pos:], start_pos
            if input_seq.shape[1] == 0: break
            attention_mask = extract_attention_mask(full_attention_mask, process_start_pos, input_seq.shape[1], cache_length)
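
            # Single forward pass over the uncached suffix; attention_bias enforces the
            # block-wise visibility pattern and update_kvcache tells the model how many
            # leading positions to commit to the cache.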
            outputs = self.model(input_seq, attention_bias=attention_mask, past_key_values=past_key_values, use_cache=True, update_kvcache=update_kvcache + cache_length)
            if update_kvcache > 0:
                past_key_values = outputs.past_key_values
                for bid in blocks_to_cache: block_states[bid]['state'] = 'in_cache'
            blocks_to_deactivate = []
            for block_id, state in block_states.items():
                if state['state'] != 'active': continue
                block_mask_locs = (x_t[0, state['start_pos']:state['end_pos']] == self.mask_token_id).nonzero().squeeze(-1)
                if block_mask_locs.numel() == 0:
                    blocks_to_deactivate.append(block_id); continue
                logit_offset = state['start_pos'] - process_start_pos
                block_mask_logits = outputs.logits[:, logit_offset + block_mask_locs, :]
                _, x0, initial_confidence = sample_tokens(block_mask_logits.squeeze(0), self.temperature, self.top_p, self.top_k)
                all_indices = (initial_confidence > skip_threshold).nonzero().squeeze(-1)
                if state['is_complete'] and all_indices.numel() == 0 and block_mask_logits.numel() > 0:
                    all_indices = torch.tensor([torch.argmax(initial_confidence)], device=self.model.device)
                if all_indices.numel() > 0:
                    updated_block_ids.add(block_id)
                    positions_to_update = state['start_pos'] + block_mask_locs[all_indices]
                    x_t[0, positions_to_update] = x0[all_indices]; state['mask_count'] -= all_indices.numel()
                    if self.tokenizer.eos_token_id in x0[all_indices]: eos_detected = True
                if state['mask_count'] == 0: blocks_to_deactivate.append(block_id)
            for bid in blocks_to_deactivate:
                if block_states[bid]['state'] == 'active' and all(block_states.get(i, {}).get('state') != 'active' for i in range(bid)):
                    block_states[bid]['state'] = 'to_cache'; current_blocks -= 1
            if update_kvcache > 0: cache_length += update_kvcache
            generated_ids = x_t[0, prompt_length:]
            valid_ids = generated_ids[generated_ids != self.mask_token_id]
            live_text = self.tokenizer.decode(valid_ids, skip_special_tokens=True)
            current_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, updated_block_ids)
            captured_frames.append(current_viz_html)
            yield live_text, captured_frames, f'<div class="generating-indicator">Generating... Step {step}</div>', False

        total_time = time.time() - start_time
        final_generated_ids = x_t[0, prompt_length:]
        eos_positions = (final_generated_ids == self.tokenizer.eos_token_id).nonzero()
        if eos_positions.numel() > 0:
            final_generated_ids = final_generated_ids[:eos_positions[0, 0] + 1]
        final_text = self.tokenizer.decode(final_generated_ids, skip_special_tokens=True)
        final_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, set())
        captured_frames.append(final_viz_html)
        tokens_incl_eos = len(final_generated_ids)
        tokens_excl_eos = len(final_generated_ids[final_generated_ids != self.tokenizer.eos_token_id])
        stats_html = f"""
        <div class="stats-card">
            <h3>✅ Generation Complete!</h3>
            <div class="stats-grid">
                <div class="stat-item">
                    <div class="stat-label">Total Time</div>
                    <div class="stat-value">{total_time:.2f}s</div>
                </div>
                <div class="stat-item">
                    <div class="stat-label">Tokens (incl. EOS)</div>
                    <div class="stat-value">{tokens_incl_eos}</div>
                </div>
                <div class="stat-item">
                    <div class="stat-label">Tokens (excl. EOS)</div>
                    <div class="stat-value">{tokens_excl_eos}</div>
                </div>
                <div class="stat-item">
                    <div class="stat-label">Tokens/Second</div>
                    <div class="stat-value">{(tokens_incl_eos / total_time):.1f}</div>
                </div>
            </div>
        </div>
        """
        yield final_text, captured_frames, stats_html, True

# --- Gradio UI and Event Handlers ---
if __name__ == "__main__":
    # Use Hugging Face Hub model IDs
    config = {
        "pretrained_path": "GSAI-ML/LLaDA-8B-Instruct",
        "lora_path": "SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora",
        "dtype": "bfloat16", "max_length": 4096,
        "temperature": 0.0, "top_p": None, "top_k": None, "mask_token_id": 126336,
        "sampling_strategy": "default",
    }
    set_seed(42)
    inference_engine = DreamLoRAInference(**config)

    def animate_visualization(html_frames_list: List[str], delay: float) -> Iterator[str]:
        if not html_frames_list:
            yield '<div class="viz-header">No visualization data captured</div>'
            return
        for frame in html_frames_list:
            yield frame
            time.sleep(delay)
    # Auto-scroll helper script injected into the page
auto_scroll_js = """ | |
<script> | |
function setupAutoScroll(containerSelector, contentSelector) { | |
const container = document.querySelector(containerSelector); | |
if (!container) return; | |
const observer = new MutationObserver(() => { | |
container.scrollTop = container.scrollHeight; | |
}); | |
observer.observe(container, { | |
childList: true, | |
subtree: true | |
}); | |
} | |
document.addEventListener('DOMContentLoaded', () => { | |
// Use a timeout to ensure Gradio elements are rendered | |
setTimeout(() => { | |
setupAutoScroll('#live-text-output', 'textarea'); | |
setupAutoScroll('.viz-container', '.viz-content'); | |
}, 1500); | |
}); | |
</script> | |
""" | |
with gr.Blocks(css=DreamLoRAInference.CSS, theme=gr.themes.Soft(), title="D2F-LLaDA Visualization") as demo: | |
html_frames_state = gr.State([]) | |
generation_complete_state = gr.State(False) | |
        gr.HTML(auto_scroll_js)  # Inject the auto-scroll helper script
        # UI layout
        with gr.Column(elem_classes=["main-container"]):
            gr.Markdown("# ✨ D2F: Faster-than-AR Inference for Diffusion LLMs")
            gr.Markdown(
                """
                [GitHub](https://github.com/zhijie-group/Discrete-Diffusion-Forcing) | [📜 Paper](https://arxiv.org/abs/2508.09192) | [🌐 Blog Post](https://zhijie-group.github.io/Discrete-Diffusion-Forcing/) | [🤗 D2F-LLaDA LoRA](https://huggingface.co/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora) | [🤗 D2F-Dream LoRA](https://huggingface.co/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora)
                """
            )
            gr.Markdown(
                """
                This demo showcases **Discrete Diffusion Forcing (D2F)**, a novel framework that enables Diffusion Language Models (dLLMs) to achieve faster-than-autoregressive inference speeds for the first time. D2F creates an AR-Diffusion hybrid paradigm that combines the efficiency of KV caching with inter-block parallel decoding.

                The model powering this demo is **LLaDA-Instruct-8B**, fine-tuned with our D2F method. Watch its unique block-wise generation in real time, then replay the process in slow motion to see how it works!
                """
            )
            with gr.Row():
                with gr.Column(scale=2):
                    prompt_input = gr.Textbox(
                        label="🤔 Enter your question",
                        placeholder="Ask me anything! Try: 'Explain quantum physics' or 'Write a story about...'",
                        lines=4,
                        elem_classes=["param-card"]
                    )
                    with gr.Accordion("⚙️ Advanced Settings", open=False):
                        with gr.Row():
                            max_new_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=2048, step=64, label="Max Tokens", info="Maximum number of tokens to generate")
                            block_size_slider = gr.Slider(minimum=16, maximum=128, value=32, step=16, label="Block Size", info="Size of each generation block")
                        with gr.Row():
                            block_add_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Block Add Threshold", info="When to add new blocks")
                            decoded_token_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Completion Threshold", info="Block completion criteria")
                        with gr.Row():
                            skip_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Skip Threshold", info="Token selection threshold")
                            delay_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.15, step=0.05, label="Playback Speed", info="Slow-motion playback delay (seconds)")
                    generate_button = gr.Button("🚀 Generate Text", variant="primary", size="lg")
                with gr.Column(scale=3):
                    with gr.Group(elem_classes=["output-text-container"]):
                        gr.Markdown("### 📝 Generated Text (Real-time)")
                        live_text_output = gr.Textbox(label="", interactive=False, lines=15, show_label=False, placeholder="Generated text will appear here as the AI thinks...", elem_id="live-text-output")
                    stats_output = gr.HTML(elem_id="stats-output")
            with gr.Row():
                with gr.Column():
                    slowmo_button = gr.Button("🎬 Watch Slow-Motion Generation Process", variant="secondary", size="lg", elem_classes=["control-button"], visible=False, interactive=False)
                    with gr.Group(elem_classes=["viz-container"], visible=False) as viz_group:
                        visualization_output = gr.HTML(label="")
            # Example prompts
            gr.Examples(
                examples=[
                    ["A circular swimming pool has a diameter of 8 meters. Calculate the pool's circumference and area. First, explain the relationship between diameter, radius, circumference, and area of a circle, including the role of π in these formulas. Then perform the calculations using π ≈ 3.14159. Next, estimate how much water (in cubic meters) would be needed to fill this pool if it has a uniform depth of 1.5 meters. Finally, calculate how much it would cost to fill this pool if water costs $2.50 per cubic meter. Show all steps and include appropriate units in your answer.", 2048, 32, 0.1, 0.5, 0.9, 0.1],
                    ["A movie theater offers a loyalty card that costs $15 and gives a 15% discount on all tickets. If a regular movie ticket costs $10, how many tickets would you need to buy to make the loyalty card worthwhile? First, explain the concept of a break-even point. Then set up an equation to find when the total cost with the card equals the total cost without the card. Solve this equation step by step, showing all your work. Finally, interpret your answer in the context of the problem.", 2048, 32, 0.1, 0.5, 0.9, 0.1],
                    ["Solve the equation x² - 6x + 8 = 0. First, explain what a quadratic equation is and why it can have up to two solutions. Then solve this equation using three different methods: factoring, completing the square, and the quadratic formula. For each method, explain the mathematical reasoning behind it, show all steps in detail, and discuss when this particular method is most useful. Finally, verify your solutions by substituting them back into the original equation.", 2048, 32, 0.1, 0.55, 0.9, 0.1],
                ],
                inputs=[prompt_input, max_new_tokens_slider, block_size_slider, block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider, delay_slider],
                label="💡 Try these examples"
            )
        # Event handlers
        def update_slowmo_button_visibility(is_complete):
            return gr.update(visible=is_complete, interactive=is_complete)

        def show_visualization():
            return gr.update(visible=True)

        inputs_list = [
            prompt_input, max_new_tokens_slider, block_size_slider,
            block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
        ]
        # Generation chain: reset the UI, stream the generation, then reveal the slow-motion replay button
        generation_event = generate_button.click(
            fn=lambda: [gr.update(visible=False, interactive=False), gr.update(visible=False), gr.update(value=None), gr.update(value="")],
            outputs=[slowmo_button, viz_group, stats_output, live_text_output]
        ).then(
            fn=inference_engine.stream_and_capture_for_gradio,
            inputs=inputs_list,
            outputs=[live_text_output, html_frames_state, stats_output, generation_complete_state]
        ).then(
            fn=update_slowmo_button_visibility,
            inputs=[generation_complete_state],
            outputs=[slowmo_button]
        )

        slowmo_event = slowmo_button.click(
            fn=show_visualization,
            outputs=[viz_group]
        ).then(
            fn=animate_visualization,
            inputs=[html_frames_state, delay_slider],
            outputs=[visualization_output]
        )

    demo.queue().launch()