# dreamon_app.py
# This app is built based on https://huggingface.co/spaces/multimodalart/Dream/blob/main/app.py
import torch
import numpy as np
import gradio as gr
import spaces  # Ensure spaces is installed if needed for GPU decorator
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import re
from typing import List, Dict, Tuple, Optional, Any, Iterable  # Added Any
import torch.distributions as dists  # Added import
import traceback  # For better error printing
import random
import gzip
import json
import subprocess
from multiprocessing import Process, Queue
import io
import sys


def unsafe_execute(prompt, completion, suffix, test_case, timeout=3):
    # NOTE: `timeout` is not enforced here; the assembled program runs in-process.
    check_program = (
        prompt + completion + suffix + "\n\n" + test_case
    )
    # Redirect stdout and stderr
    old_stdout = sys.stdout
    old_stderr = sys.stderr
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()
    sys.stdout = new_stdout
    sys.stderr = new_stderr
    try:
        # Execute the assembled program
        exec(check_program, {})
        output = new_stdout.getvalue().strip()
        error_output = new_stderr.getvalue().strip()
    except Exception as e:
        # Catch the exception and keep only its message
        output = ''
        error_output = str(e)
    finally:
        # Restore stdout and stderr
        sys.stdout = old_stdout
        sys.stderr = old_stderr
    # Summarize the captured output
    if output and not error_output:
        return output
    elif not output and not error_output:
        return 'Pass all test cases!'
    else:
        error_lines = error_output.splitlines()
        if error_lines:
            return f'Error: {error_lines[-1]}'
        else:
            return 'Error: Unknown error'


def read_problems() -> Dict[str, Dict]:
    benchmark_file = "HumanEval-SingleLineInfilling.jsonl.gz"
    return {task["task_id"]: task for task in stream_jsonl(benchmark_file)}


def stream_jsonl(filename: str) -> Iterable[Dict]:
    """
    Parses each jsonl line and yields it as a dictionary
    """
    if filename.endswith(".gz"):
        with open(filename, "rb") as gzfp:
            with gzip.open(gzfp, "rt") as fp:
                for line in fp:
                    if any(not x.isspace() for x in line):
                        yield json.loads(line)
    else:
        with open(filename, "r") as fp:
            for line in fp:
                if any(not x.isspace() for x in line):
                    yield json.loads(line)


problems = read_problems()


class HFTokenizerWrapper():
    def __init__(self, hf_tokenizer) -> None:
        self.tokenizer = hf_tokenizer
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.mask_id = self.tokenizer.mask_token_id

    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
        tokens = [self.bos_id] * add_bos + self.tokenizer.encode(s) + [self.eos_id] * add_eos
        return tokens

    def decode(self, tokens: List[int], **kwargs):
        return self.tokenizer.decode(tokens, **kwargs)

    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Return the offsets of the tokens in the original text.
        Only used for evaluation."""
        pass

    def convert_tokens_to_ids(self, tokens):
        return self.tokenizer.convert_tokens_to_ids(tokens)
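
# NOTE: `unsafe_execute` above runs untrusted code in-process and ignores its
# `timeout` argument, even though `Process` and `Queue` are imported. The sketch
# below is one possible way (illustrative only, not wired into the app) to enforce
# the timeout by running the check in a separate process.
def _unsafe_execute_worker(queue, prompt, completion, suffix, test_case):
    # Child-process entry point: push the result back through the queue.
    queue.put(unsafe_execute(prompt, completion, suffix, test_case))


def run_with_timeout(prompt, completion, suffix, test_case, timeout=3):
    """Illustrative sketch: run `unsafe_execute` in a child process with a hard timeout."""
    queue = Queue()
    proc = Process(target=_unsafe_execute_worker,
                   args=(queue, prompt, completion, suffix, test_case))
    proc.start()
    proc.join(timeout)
    if proc.is_alive():
        # The test did not finish in time: kill the child and report a timeout.
        proc.terminate()
        proc.join()
        return f'Error: execution timed out after {timeout}s'
    return queue.get() if not queue.empty() else 'Error: no result returned'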

# --- START: Copied Helper functions from generation_utils.py ---
# [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits


def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None and top_k > 0:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except:
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Extract top1 and top2 probabilities
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        # Calculate confidence as top1 - top2
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
# --- END: Copied Helper functions ---
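
# Illustrative sketch (not called by the app): how `sample_tokens` above behaves on
# toy logits. With temperature=0 it is a greedy argmax; with top_p the logits are
# nucleus-filtered first; `neg_entropy=True` returns a negative-entropy score as the
# per-position confidence. The toy values are invented for demonstration only.
def _demo_sample_tokens():
    toy_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                               [0.1, 0.2, 3.0, 0.0]])
    # Greedy decoding: confidence is the max probability, x0 the argmax id per row.
    confidence, x0 = sample_tokens(toy_logits, temperature=0.0)
    print("greedy:", x0.tolist(), confidence.tolist())
    # Stochastic decoding with nucleus filtering and entropy-based confidence.
    confidence, x0 = sample_tokens(toy_logits, temperature=0.8, top_p=0.9, neg_entropy=True)
    print("sampled:", x0.tolist(), confidence.tolist())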

# --- Model Loading and Constants ---
# Load model configuration to get special token IDs
model_path = "Dream-org/DreamOn-v0-7B"
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer = HFTokenizerWrapper(tokenizer)

print("Loading model...")
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
    trust_remote_code=True,
    attn_implementation="sdpa"  # Explicitly request SDPA
)
model = model.to(device).eval()
print("Model loaded.")

MASK_TOKEN = '<|mask|>'
MASK_ID = tokenizer.mask_id
EOS_ID = tokenizer.eos_id
try:
    EXPAND_ID = tokenizer.convert_tokens_to_ids('<|expand|>')
except:
    raise ValueError("Cannot determine EXPAND_ID. Check model's tokenizer configuration")
if MASK_ID is None:
    raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")

SPECIAL_TOKEN_IDS = {EOS_ID, MASK_ID}
try:
    IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
    IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
    SPECIAL_TOKEN_IDS.add(IM_START_ID)
    SPECIAL_TOKEN_IDS.add(IM_END_ID)
except KeyError:
    print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
    IM_START_ID = None
    IM_END_ID = None


# --- Helper Functions ---
def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
    """ Parses word constraints. """
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        part = part.strip()
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            token_ids = []
            if word:
                text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word
                # The wrapper's encode adds no BOS/EOS by default.
                token_ids = tokenizer.encode(text_to_encode)
            if token_ids and pos >= 0:
                constraints[pos] = token_ids
            elif not token_ids and word:
                print(f"Warning: Could not tokenize constraint word '{word}'")
        except ValueError:
            print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
        except Exception as e:
            print(f"Warning: Error processing constraint '{part}': {e}")
    return constraints


# Removed format_chat_history as the state will now be in the correct format
def apply_constraints_to_state(
    x: torch.Tensor,
    prompt_length: int,
    total_length: int,
    parsed_constraints: Dict[int, List[int]],
    current_step: Optional[int] = None
) -> torch.Tensor:
    """ Applies constraints directly to the state tensor `x`. """
    modified_x = x.clone()
    for rel_pos, word_token_ids in parsed_constraints.items():
        abs_start_pos = prompt_length + rel_pos
        abs_end_pos = abs_start_pos + len(word_token_ids)
        if abs_start_pos < total_length and abs_end_pos <= total_length:
            try:
                constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
                modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
            except IndexError:
                print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
            except Exception as e:
                print(f"Warning (Step {current_step}): Failed to apply constraint at {rel_pos}: {e}")
    return modified_x
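
# Illustrative sketch (the constraint helpers are not wired into this UI): the two
# functions above expect constraints written as "relative_position:word" pairs,
# e.g. "0:Hello, 5:world". The example values below are made up for demonstration.
def _demo_constraints():
    parsed = parse_constraints("0:Hello, 5:world")  # {0: [...ids...], 5: [...ids...]}
    prompt_len, total_len = 8, 32
    # Start from an all-mask state tensor and stamp the constraints into it.
    state = torch.full((1, total_len), MASK_ID, dtype=torch.long)
    state = apply_constraints_to_state(state, prompt_len, total_len, parsed)
    print(tokenizer.decode(state[0].tolist()))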

# --- Core Generation Logic with Live Visualization ---
@spaces.GPU
@torch.no_grad()
def infilling_dream(
    prefix: str,
    suffix: str,
    start_gen_len: int,
    max_gen_len: int,
    expand_budget: int,
    temperature: float,
    top_p: Optional[float],
    top_k: Optional[int],
    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float,
    delete_righthand_eos: bool,
    task_id: str
) -> Iterable[Tuple[str, str]]:
    # ---- 1. Prepare the input for infilling ----
    prefix = tokenizer.encode(prefix, add_bos=True, add_eos=False)
    prefix_len = len(prefix)
    suffix = tokenizer.encode(suffix, add_bos=False, add_eos=True)
    input_ids = prefix + [MASK_ID] * start_gen_len + suffix
    input_ids = torch.LongTensor([input_ids]).to(device)

    max_tokens = input_ids.shape[1] + max_gen_len
    num_generation_tokens = start_gen_len
    cur_generation_window_length = input_ids.shape[1] - start_gen_len + num_generation_tokens
    x = F.pad(input_ids, (0, max_tokens - input_ids.shape[1]), value=MASK_ID)

    # ---- Visualization setup ----
    initial_generated_tokens = input_ids[0, prefix_len: prefix_len + num_generation_tokens]
    # yield vis_data_initial
    yield tokenizer.decode(initial_generated_tokens.tolist()), ''
    time.sleep(visualization_delay)

    # ---- 2. Step-by-step infilling ----
    for i in range(4 * max_gen_len):
        cur_generation_window_length = input_ids.shape[1] - start_gen_len + num_generation_tokens
        attention_mask = torch.ones([input_ids.shape[0], cur_generation_window_length], dtype=torch.int16).to(input_ids.device)
        attention_mask = F.pad(attention_mask, (0, max_tokens - attention_mask.shape[1]), value=0)
        mask_index = (x == MASK_ID) & (attention_mask == 1)
        if torch.all(~mask_index[:, :cur_generation_window_length]):
            break

        tok_idx = attention_mask.long().cumsum(-1) - 1
        tok_idx.masked_fill_(attention_mask == 0, 1)
        attention_mask = torch.logical_and(
            attention_mask.unsqueeze(1).unsqueeze(-2),
            attention_mask.unsqueeze(1).unsqueeze(-1),
        )

        output = model(x, attention_mask, tok_idx)
        logits = output.logits
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
        logits = logits[mask_index]

        ## Block the logit for expansion when the token budget is all used
        if cur_generation_window_length == max_tokens or expand_budget == 0:
            logits[:, EXPAND_ID] -= 1e9

        if alg == 'maskgit_plus':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
        elif alg == 'topk_margin':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
        elif alg == 'entropy':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
        else:
            raise RuntimeError(f"Unknown alg: {alg}")

        # num_mask_token = mask_index.sum()
        # number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
        number_transfer_tokens = 1
        if number_transfer_tokens > 0:
            if alg_temp is None or alg_temp == 0:
                _, transfer_index = torch.topk(confidence, number_transfer_tokens)
            else:
                confidence = confidence / alg_temp
                confidence = F.softmax(confidence, dim=-1)
                transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
            x0_ = torch.zeros_like(x0, device=device, dtype=torch.long) + MASK_ID
            x0_[transfer_index] = x0[transfer_index].clone()
            x[mask_index] = x0_

        # Only process if batch size is 1
        if delete_righthand_eos:
            if x.shape[0] != 1:
                raise NotImplementedError
            x_seq = x[0]  # Flatten to 1D: shape [seq_len]
            # Find indices where EOS occurs
            eos_indices = (x_seq == EOS_ID).nonzero(as_tuple=True)
            if len(eos_indices[0]) > 0:
                # Get the first occurrence of EOS
                first_eos_idx = eos_indices[0][0].item()
                position_mask = torch.arange(x_seq.size(0), device=x.device) >= first_eos_idx
                replace_mask = position_mask & mask_index[0]
                # Set all masked tokens after the first EOS to eos_id
                x_seq.masked_fill_(replace_mask, EOS_ID)
                # Reshape back to original shape (unsqueeze)
                x = x_seq.unsqueeze(0)

        ## Visualize the denoise step
        cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
        cur_tokens = tokenizer.decode(cur_generated_tokens.tolist())
        ## Replace all <|endoftext|> with <|delete|>
        cur_tokens = cur_tokens.replace("<|endoftext|>", "<|delete|>")
        yield cur_tokens, ''
        time.sleep(visualization_delay)

        # Expansion step: check for expand_id and replace with two mask tokens
        expand_indices = (x[0] == EXPAND_ID).nonzero(as_tuple=False).squeeze(1)
        if expand_indices.numel() > 0:
            # Process from right to left to prevent shifting issues
            for idx in sorted(expand_indices.tolist(), reverse=True):
                x = torch.cat((
                    x[:, :idx],
                    torch.tensor([[MASK_ID, MASK_ID]], device=x.device),
                    x[:, idx + 1:]
                ), dim=1)
                num_generation_tokens += 1
                expand_budget -= 1
            # Truncate back to max_tokens if needed
            if x.shape[1] > max_tokens:
                x = x[:, :max_tokens]
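
        # The canvas may have grown above: each <|expand|> token becomes two masks,
        # and the sequence is truncated back to `max_tokens` if it overflowed.
        # Below, the current infill region is re-visualized, then any EOS tokens
        # predicted in the middle of it are removed, shrinking the visible infill
        # window again before the next denoising step.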
        cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
        vis_data = []
        # [Visualization formatting logic remains the same]
        # (vis_data is only consumed by the commented-out highlighted visualization below.)
        for j in range(num_generation_tokens):
            current_tok_id = cur_generated_tokens[j].item()
            try:
                decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
                display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
            except Exception:
                display_token = f"[ID:{current_tok_id}]"
            color = None
            token_to_display = display_token
            if current_tok_id == MASK_ID:
                color = "#444444"
            else:
                color = "#6699CC"
            if token_to_display:
                vis_data.append((token_to_display, color))
        yield tokenizer.decode(cur_generated_tokens.tolist()), ''
        # yield vis_data
        time.sleep(visualization_delay)

        ## Delete EOS tokens from the middle
        # Find indices where EOS occurs
        eos_indices = ((x[0] == EOS_ID) & (mask_index[0] == 1)).nonzero(as_tuple=False).squeeze(1)
        if eos_indices.numel() > 0:
            for idx in sorted(eos_indices.tolist(), reverse=True):
                x = torch.cat((
                    x[:, :idx],
                    x[:, idx + 1:],
                    torch.tensor([[MASK_ID]], device=x.device)
                ), dim=1)
                num_generation_tokens -= 1
        cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
        yield tokenizer.decode(cur_generated_tokens.tolist()), ''
        time.sleep(visualization_delay)

    generated_code = tokenizer.decode(x[0, prefix_len: prefix_len + num_generation_tokens].tolist())
    yield generated_code, ''


def get_example_input():
    ### This function samples a case from HumanEval-Infilling as prefix and suffix
    task_id = random.choice(list(problems.keys()))
    problem = problems[task_id]
    prefix, suffix = problem['prompt'], problem['suffix']
    test_case = problem['test']
    pattern = r'METADATA\s*=\s*\{.*?\}\n\n'
    test_case = re.sub(pattern, '', test_case, flags=re.DOTALL).strip()
    test_case = test_case.replace('def check(candidate):', 'def run_test():')
    test_case = test_case.replace('candidate', problem['entry_point'])
    test_case = test_case + '\n\nrun_test()'
    return prefix, '', suffix, test_case, task_id, ''


def check_result(prompt, completion, suffix, test_case):
    prompt = str(prompt) if prompt is not None else ""
    completion = str(completion) if completion is not None else ""
    suffix = str(suffix) if suffix is not None else ""
    test_case = str(test_case) if test_case is not None else ""
    print('prefix', prompt)
    print('middle', completion)
    print('suffix', suffix)
    print('test', test_case)
    result = unsafe_execute(prompt, completion, suffix, test_case)
    return result
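
# Illustrative sketch (not part of the app): `infilling_dream` is a generator, so it
# can also be consumed outside Gradio, e.g. from a script. The prefix/suffix strings
# below are invented placeholders; the other arguments mirror the UI defaults. This
# assumes a local GPU/CPU run rather than ZeroGPU scheduling.
def _demo_cli_generation():
    prefix = "def add(a, b):\n    "
    suffix = "\n    return result\n"
    final_text = ""
    for partial_text, _ in infilling_dream(
        prefix, suffix,
        start_gen_len=4, max_gen_len=64, expand_budget=64,
        temperature=0.2, top_p=0.95, top_k=0,
        alg='entropy', alg_temp=0.2,
        visualization_delay=0.0, delete_righthand_eos=True, task_id="",
    ):
        final_text = partial_text  # keep only the last yielded state
    print(prefix + final_text + suffix)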

# --- Gradio UI ---
css = '''
.category-legend{display:none}
'''


def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            "# DreamOn: Diffusion Language Models For Code Infilling Beyond Fixed-size Canvas\n"
            "Click **Example Prompt** to obtain a prefix and suffix, then click **Generate** to create the code.\n\n"
            "Click **Run Test Cases** to verify the correctness of the code. You can also input your own test cases."
        )
        gr.Markdown(
            "[[Model Card](https://huggingface.co/Dream-org/DreamOn-v0-7B)] "
            "[[Blog](https://hkunlp.github.io/blog/2025/dreamon/)] "
            "[[Github](https://github.com/DreamLM/DreamOn)]"
        )
        with gr.Row():
            sample_btn = gr.Button("Example Prompt")
            generate_btn = gr.Button("Generate", variant="primary")
            check_btn = gr.Button("Run Test Cases")
            clear_btn = gr.Button("Clear")
        with gr.Row():
            with gr.Column():
                # Prefix input
                prefix_input = gr.Textbox(
                    label="Prefix Text",
                    placeholder="Enter the beginning of your text...",
                    lines=2
                )
                # Middle generation/visualization area
                output_vis = gr.Textbox(
                    label="Generated Text (Middle)",
                    lines=2
                )
                # Suffix input
                suffix_input = gr.Textbox(
                    label="Suffix Text",
                    placeholder="Enter the end of your text...",
                    lines=2
                )
                # Hidden Task ID input
                task_id_input = gr.Textbox(
                    label="Task ID",
                    placeholder="Task ID will be stored here...",
                    visible=False
                )
            with gr.Column():
                # Test Case input
                test_case_input = gr.Textbox(
                    label="Test Case",
                    placeholder="Enter your test case here...",
                    lines=2
                )
                # Result of execution
                result_output = gr.Textbox(
                    label="Result of Execution",
                    placeholder="Execution result will be shown here...",
                    lines=2
                )
        # Generation Settings
        with gr.Accordion("Generation Settings"):
            with gr.Row():
                start_gen_len = gr.Slider(
                    minimum=4, maximum=64, value=4, step=4,
                    label="Initial Generation Length"
                )
                max_gen_len = gr.Slider(
                    minimum=32, maximum=64, value=64, step=8,
                    label="Maximum Generation Length"
                )
            with gr.Row():
                expand_budget = gr.Slider(
                    minimum=0, maximum=256, value=64, step=8,
                    label="Expansion Budget"
                )
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.2, step=0.05,
                    label="Temperature"
                )
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                    label="Top-P (0 disables)"
                )
            with gr.Row():
                top_k = gr.Slider(
                    minimum=0, maximum=200, value=0, step=5,
                    label="Top-K (0 disables)"
                )
            with gr.Row():
                alg = gr.Radio(
                    choices=['maskgit_plus', 'topk_margin', 'entropy'],
                    value='entropy',
                    label="Generation Algorithm"
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.2, step=0.05,
                    label="Algorithm Temperature"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.1, maximum=3, value=0.2, step=0.1,
                    label="Visualization Delay (s)"
                )
                pad_delete_righthand = gr.Checkbox(
                    label="Delete all tokens on the righthand side of <|delete|>",
                    value=True
                )

        # Connect the UI elements
        generation_inputs = [
            prefix_input, suffix_input,
            start_gen_len, max_gen_len, expand_budget,
            temperature, top_p, top_k,
            alg, alg_temp,
            visualization_delay, pad_delete_righthand,
            task_id_input
        ]
        test_inputs = [prefix_input, output_vis, suffix_input, test_case_input]

        generate_btn.click(
            fn=lambda: ("Waiting for ZeroGPU...", ''),
            outputs=[output_vis, result_output],
            show_progress="hidden"
        ).then(
            fn=infilling_dream,
            inputs=generation_inputs,
            outputs=[output_vis, result_output],
            show_progress="hidden"
        )
        clear_btn.click(
            lambda: ("", "", "", "", "", ""),  # Clear all inputs and outputs
            inputs=[],
            outputs=[prefix_input, suffix_input, output_vis, test_case_input, result_output, task_id_input],
            queue=False
        )
        check_btn.click(
            fn=check_result,
            inputs=test_inputs,
            outputs=[result_output],
            queue=False
        )
        sample_btn.click(
            fn=get_example_input,
            outputs=[prefix_input, output_vis, suffix_input, test_case_input, task_id_input, result_output],
            queue=False
        )
    return demo


# --- Launch ---
if __name__ == "__main__":
    # test()
    demo = create_chatbot_demo()
    demo.queue().launch(debug=True)