# dreamon_app.py
## This app is built on top of https://huggingface.co/spaces/multimodalart/Dream/blob/main/app.py
import torch
import numpy as np
import gradio as gr
import spaces # Ensure spaces is installed if needed for GPU decorator
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import re
from typing import List, Dict, Tuple, Optional, Any, Iterable # Added Any
import torch.distributions as dists # Added import
import traceback # For better error printing
import random
import gzip
import json
import subprocess
from multiprocessing import Process, Queue
import io
import sys
def unsafe_execute(prompt, completion, suffix, test_case, timeout=3):
check_program = (prompt
+ completion
+ suffix
+ "\n\n"
+ test_case
)
    # Redirect stdout and stderr so the executed program's output can be captured
old_stdout = sys.stdout
old_stderr = sys.stderr
new_stdout = io.StringIO()
new_stderr = io.StringIO()
sys.stdout = new_stdout
sys.stderr = new_stderr
try:
        # Execute the assembled program
exec(check_program, {})
output = new_stdout.getvalue().strip()
error_output = new_stderr.getvalue().strip()
except Exception as e:
        # Catch the exception and keep its message
output = ''
error_output = str(e)
finally:
        # Restore stdout and stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
    # Summarize the captured output
    if output and not error_output:
        return output
    elif not output and not error_output:
        return 'Pass all test cases!'
    else:
        error_lines = error_output.splitlines()
        if error_lines:
            return f'Error: {error_lines[-1]}'
        else:
            return 'Error: Unknown error'
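# Illustrative usage only (hypothetical example, not part of the app's control flow):
#   unsafe_execute("def add(a, b):\n", "    return a + b\n", "", "assert add(1, 2) == 3")
# returns 'Pass all test cases!' when the assembled program runs silently, the captured stdout
# when it prints something, and 'Error: <last line of the error message>' when it raises.
# Note: the `timeout` parameter is currently unused, and the code runs in-process without any
# sandboxing, so only execute prompts and test cases you trust.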
def read_problems() -> Dict[str, Dict]:
benchmark_file = "HumanEval-SingleLineInfilling.jsonl.gz"
return {task["task_id"]: task for task in stream_jsonl(benchmark_file)}
def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
Parses each jsonl line and yields it as a dictionary
"""
if filename.endswith(".gz"):
with open(filename, "rb") as gzfp:
with gzip.open(gzfp, "rt") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
else:
with open(filename, "r") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
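# Load the HumanEval-SingleLineInfilling benchmark once at import time; `problems` maps
# task_id -> problem dict with (at least) 'prompt', 'suffix', 'test' and 'entry_point'.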
problems = read_problems()
class HFTokenizerWrapper:
    """Thin wrapper exposing bos/eos/mask ids plus simple encode/decode helpers."""
    def __init__(self, hf_tokenizer) -> None:
        self.tokenizer = hf_tokenizer
self.bos_id = self.tokenizer.bos_token_id
self.eos_id = self.tokenizer.eos_token_id
self.mask_id = self.tokenizer.mask_token_id
def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
tokens = [self.bos_id] * add_bos + self.tokenizer.encode(s) + [self.eos_id] * add_eos
return tokens
def decode(self, tokens: List[int], **kwargs):
return self.tokenizer.decode(tokens, **kwargs)
def get_token_offsets(
self, text: str, tokens: Optional[List[int]] = None
) -> Tuple[List[str], List[int]]:
"""Return the offsets of the tokens in the original text. Only used for evaluation."""
pass
def convert_tokens_to_ids(self, tokens):
return self.tokenizer.convert_tokens_to_ids(tokens)
# --- START: Copied Helper functions from generation_utils.py ---
# [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
def top_p_logits(logits, top_p=None):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
return logits
def top_k_logits(logits, top_k=None):
top_k = min(top_k, logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
if temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1:
logits = top_p_logits(logits, top_p)
if top_k is not None and top_k > 0:
logits = top_k_logits(logits, top_k)
probs = torch.softmax(logits, dim=-1)
if temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # fall back to greedy decoding if Categorical sampling fails (e.g. invalid probs)
confidence, x0 = probs.max(dim=-1)
else:
confidence, x0 = probs.max(dim=-1)
if margin_confidence:
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
# Extract top1 and top2 probabilities
top1_probs = sorted_probs[:, 0]
top2_probs = sorted_probs[:, 1]
# Calculate confidence as top1 - top2
confidence = top1_probs - top2_probs
if neg_entropy:
epsilon = 1e-10
log_probs = torch.log(probs + epsilon)
confidence = torch.sum(probs * log_probs, dim=-1)
return confidence, x0
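# Confidence semantics used by the unmasking schedule below:
#   - default (maskgit_plus): probability of the sampled/argmax token,
#   - margin_confidence (topk_margin): top-1 minus top-2 probability,
#   - neg_entropy (entropy): sum(p * log p), i.e. negative entropy, so lower-entropy
#     (more certain) positions get higher confidence.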
# --- END: Copied Helper functions ---
# --- Model Loading and Constants ---
# Load model configuration to get special token IDs
model_path = "Dream-org/DreamOn-v0-7B"
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer = HFTokenizerWrapper(tokenizer)
print("Loading model...")
model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
trust_remote_code=True,
attn_implementation="sdpa" # Explicitly request SDPA
)
model = model.to(device).eval()
print("Model loaded.")
MASK_TOKEN = '<|mask|>'
MASK_ID = tokenizer.mask_id
EOS_ID = tokenizer.eos_id
try:
    EXPAND_ID = tokenizer.convert_tokens_to_ids('<|expand|>')
except Exception:
    EXPAND_ID = None
if EXPAND_ID is None:
    raise ValueError("Cannot determine EXPAND_ID. Check the model's tokenizer configuration.")
if MASK_ID is None:
raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")
SPECIAL_TOKEN_IDS = {EOS_ID, MASK_ID}
IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
if IM_START_ID is None or IM_END_ID is None:
    print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
    IM_START_ID = None
    IM_END_ID = None
else:
    SPECIAL_TOKEN_IDS.add(IM_START_ID)
    SPECIAL_TOKEN_IDS.add(IM_END_ID)
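# Token roles in DreamOn's variable-length infilling (as used by infilling_dream below):
#   <|mask|>   marks positions that still have to be denoised,
#   <|expand|> asks the model to grow the canvas (it is replaced by two masks),
#   an EOS predicted inside the middle region acts as <|delete|> and shrinks the canvas.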
# --- Helper Functions ---
def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
""" Parses word constraints. """
constraints = {}
if not constraints_text: return constraints
parts = constraints_text.split(',')
for part in parts:
part = part.strip()
if ':' not in part: continue
pos_str, word = part.split(':', 1)
try:
pos = int(pos_str.strip())
word = word.strip()
token_ids = []
if word:
text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word
                token_ids = tokenizer.encode(text_to_encode)  # HFTokenizerWrapper.encode has no add_special_tokens argument
if token_ids and pos >= 0: constraints[pos] = token_ids
elif not token_ids and word: print(f"Warning: Could not tokenize constraint word '{word}'")
except ValueError: print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
except Exception as e: print(f"Warning: Error processing constraint '{part}': {e}")
return constraints
# Removed format_chat_history as the state will now be in the correct format
def apply_constraints_to_state(
x: torch.Tensor,
prompt_length: int,
total_length: int,
parsed_constraints: Dict[int, List[int]],
current_step: Optional[int] = None
) -> torch.Tensor:
""" Applies constraints directly to the state tensor `x`. """
modified_x = x.clone()
for rel_pos, word_token_ids in parsed_constraints.items():
abs_start_pos = prompt_length + rel_pos
abs_end_pos = abs_start_pos + len(word_token_ids)
if abs_start_pos < total_length and abs_end_pos <= total_length:
try:
constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
except IndexError: print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
except Exception as e: print(f"Warning (Step {current_step}): Failed to apply constraint at {rel_pos}: {e}")
return modified_x
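# Note: parse_constraints and apply_constraints_to_state are carried over from the original
# Dream demo; the current UI does not expose word constraints, so they are not called here.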
# --- Core Generation Logic with Live Visualization ---
@spaces.GPU
@torch.no_grad()
def infilling_dream(
prefix: str,
suffix: str,
start_gen_len: int,
max_gen_len: int,
expand_budget: int,
temperature: float,
top_p: Optional[float],
top_k: Optional[int],
alg: str,
alg_temp: Optional[float],
visualization_delay: float,
delete_righthand_eos: bool,
task_id: str
) -> Iterable[Tuple[str, str]]:
# ------1. Prepare the input for infilling -----------------
    # Treat the UI's sentinel values as "disabled" (the sliders are labelled "(0 disables)").
    if top_p is not None and top_p <= 0:
        top_p = None
    if top_k is not None and top_k <= 0:
        top_k = None
    prefix = tokenizer.encode(prefix, add_bos=True, add_eos=False)
    prefix_len = len(prefix)
    suffix = tokenizer.encode(suffix, add_bos=False, add_eos=True)
input_ids = prefix + [MASK_ID] * start_gen_len + suffix
input_ids = torch.LongTensor([input_ids]).to(device)
max_tokens = input_ids.shape[1] + max_gen_len
num_generation_tokens = start_gen_len
cur_generation_window_length = input_ids.shape[1] - start_gen_len + num_generation_tokens
x = F.pad(input_ids, (0, max_tokens - input_ids.shape[1]), value = MASK_ID)
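    # x is the full working canvas: [BOS + prefix][start_gen_len masks][suffix + EOS][padding masks],
    # right-padded with <|mask|> up to max_tokens so the generation window can grow without reallocation.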
# ------ Visualization Setup
initial_generated_tokens = input_ids[0, prefix_len: prefix_len + num_generation_tokens]
#yield vis_data_initial
yield tokenizer.decode(initial_generated_tokens.tolist()), ''
time.sleep(visualization_delay)
# ----2. Step by Step Infilling ----------------------------------------
for i in range(4 * max_gen_len):
cur_generation_window_length = input_ids.shape[1] - start_gen_len + num_generation_tokens
attention_mask = torch.ones([input_ids.shape[0], cur_generation_window_length], dtype = torch.int16).to(input_ids.device)
attention_mask = F.pad(attention_mask, (0, max_tokens - attention_mask.shape[1]), value = 0)
mask_index = (x == MASK_ID) & (attention_mask == 1)
if torch.all(~mask_index[:,:cur_generation_window_length]):
break
tok_idx = attention_mask.long().cumsum(-1) - 1
tok_idx.masked_fill_(attention_mask == 0, 1)
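        # tok_idx provides position ids over the active window; padded positions get a dummy
        # index and are excluded from attention by the pairwise mask built next.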
attention_mask = torch.logical_and(
attention_mask.unsqueeze(1).unsqueeze(-2),
attention_mask.unsqueeze(1).unsqueeze(-1),
)
output = model(x, attention_mask, tok_idx)
logits = output.logits
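        # Shift the logits one position to the right so logits[:, i] scores the token at
        # position i (the model's output at position i-1 predicts position i, as in an AR LM).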
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
logits = logits[mask_index]
## block the logit for expansion when token budget is all used
if cur_generation_window_length == max_tokens or expand_budget == 0:
logits[:,EXPAND_ID] -= 1e9
if alg == 'maskgit_plus':
confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
elif alg == 'topk_margin':
confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
elif alg == 'entropy':
confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
else:
raise RuntimeError(f"Unknown alg: {alg}")
#num_mask_token = mask_index.sum()
#number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
number_transfer_tokens = 1
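        # Commit exactly one masked position per denoising step: the highest-confidence
        # prediction when alg_temp is 0, otherwise a position sampled from softmax(confidence / alg_temp).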
if number_transfer_tokens > 0:
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(confidence, number_transfer_tokens)
else:
confidence = confidence / alg_temp
confidence = F.softmax(confidence, dim=-1)
transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
x0_ = torch.zeros_like(x0, device=device, dtype=torch.long) + MASK_ID
x0_[transfer_index] = x0[transfer_index].clone()
x[mask_index] = x0_
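        # Optional: once an EOS (<|delete|>) appears, also set every still-masked position to its
        # right to EOS, so the whole tail is removed together in the deletion step below.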
# Only process if batch size is 1
if delete_righthand_eos:
if x.shape[0] != 1:
raise NotImplementedError
x_seq = x[0] # Flatten to 1D: shape [seq_len]
# Find indices where EOS occurs
eos_indices = (x_seq == EOS_ID).nonzero(as_tuple=True)
if len(eos_indices[0]) > 0:
# Get the first occurrence of EOS
# mask indices
first_eos_idx = eos_indices[0][0].item()
position_mask = torch.arange(x_seq.size(0), device=x.device) >= first_eos_idx
replace_mask = position_mask & mask_index[0]
# Set all tokens after EOS to eos_id
x_seq.masked_fill_(replace_mask, EOS_ID)
# # Reshape back to original shape (unsqueeze)
x = x_seq.unsqueeze(0)
## Visualize Denoise Step
cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
cur_tokens = tokenizer.decode(cur_generated_tokens.tolist())
        ## For display only: show predicted <|endoftext|> tokens as <|delete|>
        cur_tokens = cur_tokens.replace("<|endoftext|>", "<|delete|>")
yield cur_tokens, ''
time.sleep(visualization_delay)
# Expansion Step: Check for expand_id and replace with two mask tokens
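        # Each <|expand|> is replaced by two <|mask|> tokens, so the generation window grows by
        # one net token per expansion and the remaining expand_budget shrinks by one.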
expand_indices = (x[0] == EXPAND_ID).nonzero(as_tuple=False).squeeze(1)
if expand_indices.numel() > 0:
# Process from right to left to prevent shifting issues
for idx in sorted(expand_indices.tolist(), reverse=True):
x = torch.cat((
x[:, :idx],
torch.tensor([[MASK_ID, MASK_ID]], device=x.device),
x[:, idx + 1:]
), dim=1)
num_generation_tokens += 1
expand_budget -= 1
# Truncate back to max_tokens if needed
if x.shape[1] > max_tokens:
x = x[:, :max_tokens]
cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
        # Leftover from the original Dream demo's highlighted-token view: vis_data is built in
        # the loop below but is no longer yielded (the plain decoded string is shown instead).
        vis_data = []
for j in range(num_generation_tokens):
current_tok_id = cur_generated_tokens[j].item()
try:
decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
except Exception: display_token = f"[ID:{current_tok_id}]"
color = None; token_to_display = display_token
if current_tok_id == MASK_ID: color = "#444444"
else: color = "#6699CC"
if token_to_display: vis_data.append((token_to_display, color))
yield tokenizer.decode(cur_generated_tokens.tolist()), ''
#yield vis_data
time.sleep(visualization_delay)
        ## Deletion step: an EOS predicted inside the still-masked middle acts as <|delete|>.
        # Remove that position and append a fresh mask at the end, so the tensor width stays
        # constant while the generation window shrinks by one token.
eos_indices = ((x[0] == EOS_ID) & (mask_index[0] == 1)).nonzero(as_tuple=False).squeeze(1)
if eos_indices.numel() > 0:
for idx in sorted(eos_indices.tolist(), reverse=True):
x = torch.cat((
x[:, :idx],
x[:, idx + 1:],
torch.tensor([[MASK_ID]], device = x.device)
), dim = 1)
num_generation_tokens -= 1
cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
yield tokenizer.decode(cur_generated_tokens.tolist()), ''
time.sleep(visualization_delay)
generated_code = tokenizer.decode(x[0, prefix_len: prefix_len + num_generation_tokens].tolist())
yield generated_code, ''
def get_example_input():
    ### Sample a problem from HumanEval-SingleLineInfilling and rewrite its check(candidate)
    ### harness into a directly runnable run_test() for the "Run Test Cases" button.
task_id = random.choice(list(problems.keys()))
problem = problems[task_id]
prefix, suffix = problem['prompt'], problem['suffix']
test_case = problem['test']
pattern = r'METADATA\s*=\s*\{.*?\}\n\n'
test_case= re.sub(pattern, '', test_case, flags=re.DOTALL).strip()
test_case = test_case.replace('def check(candidate):', 'def run_test():')
test_case = test_case.replace('candidate', problem['entry_point'])
test_case = test_case + '\n\nrun_test()'
return prefix, '', suffix, test_case, task_id, ''
def check_result(prompt, completion, suffix, test_case):
prompt = str(prompt) if prompt is not None else ""
completion = str(completion) if completion is not None else ""
suffix = str(suffix) if suffix is not None else ""
test_case = str(test_case) if test_case is not None else ""
print('prefix', prompt)
print('middle', completion)
print('suffix', suffix)
print('test', test_case)
result = unsafe_execute(prompt, completion, suffix, test_case)
return result
# --- Gradio UI ---
css = '''
.category-legend{display:none}
'''
def create_chatbot_demo():
with gr.Blocks(css=css) as demo:
gr.Markdown("# DreamOn: Diffusion Language Models For Code Infilling Beyond Fixed-size Canvas\nClick **Example Prompt** to obtain a prefix and suffix, then click **Generate** to create the code.\n\nClick **Run Test Cases** to verify the correctness of the code. You can also input your own test cases.")
gr.Markdown(
"[[Model Card](https://huggingface.co/Dream-org/DreamOn-v0-7B)] "
"[[Blog](https://hkunlp.github.io/blog/2025/dreamon/)]"
"[[Github](https://github.com/DreamLM/DreamOn)]"
)
with gr.Row():
sample_btn = gr.Button("Example Prompt")
generate_btn = gr.Button("Generate", variant="primary")
check_btn = gr.Button("Run Test Cases")
clear_btn = gr.Button("Clear")
with gr.Row():
with gr.Column():
# Prefix input
prefix_input = gr.Textbox(
label="Prefix Text",
placeholder="Enter the beginning of your text...",
lines=2
)
# Middle generation/visualization area
output_vis = gr.Textbox(
label="Generated Text (Middle)",
lines=2
)
# Suffix input
suffix_input = gr.Textbox(
label="Suffix Text",
placeholder="Enter the end of your text...",
lines=2
)
# Hidden Task ID input
task_id_input = gr.Textbox(
label="Task ID",
placeholder="Task ID will be stored here...",
visible=False
)
with gr.Column():
# Test Case input
test_case_input = gr.Textbox(
label="Test Case",
placeholder="Enter your test case here...",
lines=2
)
# Result of execution
result_output = gr.Textbox(
label="Result of Execution",
placeholder="Execution result will be shown here...",
lines=2
)
# Generation Settings
with gr.Accordion("Generation Settings"):
with gr.Row():
start_gen_len = gr.Slider(
minimum=4,
maximum=64,
value=4,
step=4,
label="Initial Generation Length"
)
max_gen_len = gr.Slider(
minimum=32,
maximum=64,
value=64,
step=8,
label="Maximum Generation Length"
)
with gr.Row():
expand_budget = gr.Slider(
minimum=0,
maximum=256,
value=64,
step=8,
label="Expansion Budget"
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.05,
label="Temperature"
)
with gr.Row():
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-P (0 disables)"
)
with gr.Row():
top_k = gr.Slider(
minimum=0,
maximum=200,
value=0,
step=5,
label="Top-K (0 disables)")
with gr.Row():
alg = gr.Radio(
choices=['maskgit_plus', 'topk_margin', 'entropy'],
value='entropy',
label="Generation Algorithm"
)
alg_temp = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.05,
label="Algorithm Temperature"
)
with gr.Row():
visualization_delay = gr.Slider(
minimum=0.1,
maximum=3,
value=0.2,
step=0.1,
label="Visualization Delay (s)"
)
pad_delete_righthand = gr.Checkbox(
label="Delete all tokens on the righthand side of <|delete|>",
value=True
)
# Connect the UI elements
generation_inputs = [
prefix_input,
suffix_input,
start_gen_len,
max_gen_len,
expand_budget,
temperature,
top_p,
top_k,
alg,
alg_temp,
visualization_delay,
pad_delete_righthand,
task_id_input
]
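        # The order of generation_inputs must match infilling_dream's positional signature.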
test_inputs=[prefix_input, output_vis, suffix_input, test_case_input]
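        # generate_btn first flashes a placeholder, then streams infilling_dream's generator
        # output into output_vis via .then(); each yield updates the textbox live.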
generate_btn.click(
fn=lambda: ("Waiting for ZeroGPU...", ''),
outputs=[output_vis, result_output],
show_progress="hidden"
).then(
fn=infilling_dream,
inputs=generation_inputs,
outputs=[output_vis, result_output],
show_progress="hidden"
)
clear_btn.click(
lambda: ("", "", "", "", "", ""), # Clear all inputs and outputs
inputs=[],
outputs=[prefix_input, suffix_input, output_vis, test_case_input, result_output, task_id_input],
queue=False
)
check_btn.click(
fn=check_result,
inputs=test_inputs,
outputs=[result_output],
queue=False
)
sample_btn.click(
fn=get_example_input,
outputs=[prefix_input, output_vis, suffix_input, test_case_input, task_id_input, result_output],
queue=False
)
return demo
# --- Launch ---
if __name__ == "__main__":
#test()
demo = create_chatbot_demo()
demo.queue().launch(debug=True)