# dreamon_app.py
## This app is based on https://huggingface.co/spaces/multimodalart/Dream/blob/main/app.py
import torch
import numpy as np
import gradio as gr
import spaces # Ensure spaces is installed if needed for GPU decorator
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import re
from typing import List, Dict, Tuple, Optional, Any, Iterable # Added Any
import torch.distributions as dists # Added import
import traceback # For better error printing
import random
import gzip
import json
import subprocess
from multiprocessing import Process, Queue

import io
import sys

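# unsafe_execute stitches prompt + completion + suffix + test_case into one program and
# runs it with exec() in the current process while capturing stdout/stderr. It is only
# meant for demo-style checking: there is no sandboxing, and the timeout argument is
# accepted but not enforced.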
def unsafe_execute(prompt, completion, suffix, test_case, timeout=3):
    check_program = (prompt
        + completion
        + suffix
        + "\n\n"
        + test_case
    )

    # Redirect stdout and stderr to capture the executed program's output
    old_stdout = sys.stdout
    old_stderr = sys.stderr
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()
    sys.stdout = new_stdout
    sys.stderr = new_stderr

    try:
        # Execute the assembled program in an empty globals namespace
        exec(check_program, {})
        output = new_stdout.getvalue().strip()
        error_output = new_stderr.getvalue().strip()
    except Exception as e:
        # Catch any exception and keep its message
        output = ''
        error_output = str(e)
    finally:
        # Restore stdout and stderr
        sys.stdout = old_stdout
        sys.stderr = old_stderr

    # Interpret the captured output
    if output and not error_output:
        return output
    elif not output and not error_output:
        return 'Pass all test cases!'
    else:
        error_lines = error_output.splitlines()
        if error_lines:
            return f'Error: {error_lines[-1]}'
        return 'Error: Unknown error'


def read_problems() -> Dict[str, Dict]:
    benchmark_file = "HumanEval-SingleLineInfilling.jsonl.gz"
    return {task["task_id"]: task for task in stream_jsonl(benchmark_file)}


def stream_jsonl(filename: str) -> Iterable[Dict]:
    """
    Parses each jsonl line and yields it as a dictionary
    """
    if filename.endswith(".gz"):
        with open(filename, "rb") as gzfp:
            with gzip.open(gzfp, "rt") as fp:
                for line in fp:
                    if any(not x.isspace() for x in line):
                        yield json.loads(line)
    else:
        with open(filename, "r") as fp:
            for line in fp:
                if any(not x.isspace() for x in line):
                    yield json.loads(line)

problems = read_problems()


class HFTokenizerWrapper:
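    """Thin adapter around a Hugging Face tokenizer that exposes the bos/eos/mask token
    ids and provides encode/decode with explicit add_bos/add_eos flags."""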
    def __init__(self, hf_tokenizer) -> None:
        self.tokenizer = hf_tokenizer
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.mask_id = self.tokenizer.mask_token_id

    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
        tokens = [self.bos_id] * add_bos + self.tokenizer.encode(s) + [self.eos_id] * add_eos
        return tokens

    def decode(self, tokens: List[int], **kwargs):
        return self.tokenizer.decode(tokens, **kwargs)
    
    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Return the offsets of the tokens in the original text. Only used for evaluation."""
        pass

    def convert_tokens_to_ids(self, tokens):
        return self.tokenizer.convert_tokens_to_ids(tokens)
# --- START: Copied Helper functions from generation_utils.py ---
# [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits

def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits

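# sample_tokens applies temperature scaling and optional top-p / top-k filtering, then
# returns a (confidence, token) pair per position: confidence is the chosen token's
# probability by default, the top1 - top2 margin with margin_confidence=True, or
# negative entropy with neg_entropy=True.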
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):

    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None and top_k > 0:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:  # fall back to the greedy argmax if categorical sampling fails
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Extract top1 and top2 probabilities
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        # Calculate confidence as top1 - top2
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
# --- END: Copied Helper functions ---


# --- Model Loading and Constants ---
# Load model configuration to get special token IDs
model_path = "Dream-org/DreamOn-v0-7B"
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer = HFTokenizerWrapper(tokenizer)
print("Loading model...")
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
    trust_remote_code=True,
    attn_implementation="sdpa" # Explicitly request SDPA
)
model = model.to(device).eval()
print("Model loaded.")

MASK_TOKEN = '<|mask|>'
MASK_ID = tokenizer.mask_id
EOS_ID = tokenizer.eos_id
try:
    EXPAND_ID = tokenizer.convert_tokens_to_ids('<|expand|>')
except Exception:
    raise ValueError("Cannot determine EXPAND_ID. Check the model's tokenizer configuration.")


if MASK_ID is None:
    raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")

SPECIAL_TOKEN_IDS = {EOS_ID, MASK_ID}
try:
    IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
    IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
    SPECIAL_TOKEN_IDS.add(IM_START_ID)
    SPECIAL_TOKEN_IDS.add(IM_END_ID)
except KeyError:
    print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
    IM_START_ID = None
    IM_END_ID = None


# --- Helper Functions ---
def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
    """ Parses word constraints. """
    constraints = {}
    if not constraints_text: return constraints
    parts = constraints_text.split(',')
    for part in parts:
        part = part.strip()
        if ':' not in part: continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            token_ids = []
            if word:
                text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word
                # The wrapper's encode() has no add_special_tokens flag; use the underlying tokenizer
                token_ids = tokenizer.tokenizer.encode(text_to_encode, add_special_tokens=False)
            if token_ids and pos >= 0: constraints[pos] = token_ids
            elif not token_ids and word: print(f"Warning: Could not tokenize constraint word '{word}'")
        except ValueError: print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
        except Exception as e: print(f"Warning: Error processing constraint '{part}': {e}")
    return constraints

# Removed format_chat_history as the state will now be in the correct format

def apply_constraints_to_state(
    x: torch.Tensor,
    prompt_length: int,
    total_length: int,
    parsed_constraints: Dict[int, List[int]],
    current_step: Optional[int] = None
) -> torch.Tensor:
    """ Applies constraints directly to the state tensor `x`. """
    modified_x = x.clone()
    for rel_pos, word_token_ids in parsed_constraints.items():
        abs_start_pos = prompt_length + rel_pos
        abs_end_pos = abs_start_pos + len(word_token_ids)
        if abs_start_pos < total_length and abs_end_pos <= total_length:
            try:
                constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
                modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
            except IndexError: print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
            except Exception as e: print(f"Warning (Step {current_step}): Failed to apply constraint at {rel_pos}: {e}")
    return modified_x
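
# Note: parse_constraints and apply_constraints_to_state are carried over from the
# original Dream demo; the infilling_dream generator below does not use them.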



# --- Core Generation Logic with Live Visualization ---

@spaces.GPU
@torch.no_grad()
def infilling_dream(
    prefix: str,
    suffix: str,
    start_gen_len: int,
    max_gen_len: int,
    expand_budget: int,
    temperature: float,
    top_p: Optional[float],
    top_k: Optional[int],
    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float,
    delete_righthand_eos: bool,
    task_id: str
) -> Iterable[Tuple[str, str]]:
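    """
    Stream the infilling process for live visualization.

    This is a generator: it yields (current_middle_text, result_placeholder) tuples
    after every denoising / expansion / deletion update, ending with the final code.
    """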
    # ------ 1. Prepare the input for infilling -----------------
    prefix = tokenizer.encode(prefix, add_bos=True, add_eos=False)
    prefix_len = len(prefix)
    suffix = tokenizer.encode(suffix, add_bos=False, add_eos=True)
    input_ids = prefix + [MASK_ID] * start_gen_len + suffix
    input_ids = torch.LongTensor([input_ids]).to(device)
    max_tokens = input_ids.shape[1] + max_gen_len
    num_generation_tokens = start_gen_len
    cur_generation_window_length = input_ids.shape[1] - start_gen_len + num_generation_tokens
    x = F.pad(input_ids, (0, max_tokens - input_ids.shape[1]), value = MASK_ID)

    # ------ Visualization Setup
    initial_generated_tokens = input_ids[0, prefix_len: prefix_len + num_generation_tokens]
    #yield vis_data_initial
    yield tokenizer.decode(initial_generated_tokens.tolist()), ''
    time.sleep(visualization_delay)

    # ----2. Step by Step Infilling ----------------------------------------
    for i in range(4 * max_gen_len):

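        # Build a padding-aware position index and a [B, 1, L, L] attention mask that
        # covers only the active window (prefix + current generation tokens + suffix).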
        cur_generation_window_length = input_ids.shape[1] - start_gen_len + num_generation_tokens
        attention_mask = torch.ones([input_ids.shape[0], cur_generation_window_length], dtype = torch.int16).to(input_ids.device)
        attention_mask = F.pad(attention_mask, (0, max_tokens - attention_mask.shape[1]), value = 0)

        mask_index = (x == MASK_ID) & (attention_mask == 1)
        if torch.all(~mask_index[:,:cur_generation_window_length]):
            break
        
        tok_idx = attention_mask.long().cumsum(-1) - 1
        tok_idx.masked_fill_(attention_mask == 0, 1)
        
        attention_mask = torch.logical_and(
            attention_mask.unsqueeze(1).unsqueeze(-2),
            attention_mask.unsqueeze(1).unsqueeze(-1),
        )
        
        output = model(x, attention_mask, tok_idx)
        logits = output.logits
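        # Shift logits one position to the right so that logits[:, i] scores the token at
        # position i (same shift as in Dream's generation utilities), then keep only the
        # masked positions.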
        logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
        logits = logits[mask_index] 

        ## Block the expansion logit once the token budget is exhausted
        if cur_generation_window_length == max_tokens or expand_budget == 0:
            logits[:,EXPAND_ID] -= 1e9
        

        if alg == 'maskgit_plus':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
        elif alg == 'topk_margin':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
        elif alg == 'entropy':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
        else:
            raise RuntimeError(f"Unknown alg: {alg}")
        # Transfer exactly one token per step (the step-count-based transfer schedule
        # from the original Dream app is not used here)
        number_transfer_tokens = 1
        if number_transfer_tokens > 0:
            if alg_temp is None or alg_temp == 0:
                _, transfer_index = torch.topk(confidence, number_transfer_tokens)
            else:
                confidence = confidence / alg_temp
                confidence = F.softmax(confidence, dim=-1)
                transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
            x0_ = torch.zeros_like(x0, device=device, dtype=torch.long) + MASK_ID
            x0_[transfer_index] = x0[transfer_index].clone()
            x[mask_index] = x0_
        # The right-hand EOS cleanup below only supports batch size 1
        if delete_righthand_eos:
            if x.shape[0] != 1:
                raise NotImplementedError
            x_seq = x[0]  # Flatten to 1D: shape [seq_len]

            # Find indices where EOS occurs
            eos_indices = (x_seq == EOS_ID).nonzero(as_tuple=True)

            if len(eos_indices[0]) > 0:
                # Find the first occurrence of EOS
                first_eos_idx = eos_indices[0][0].item()
                position_mask = torch.arange(x_seq.size(0), device=x.device) >= first_eos_idx
                replace_mask = position_mask & mask_index[0]
                # Set all tokens after EOS to eos_id
                x_seq.masked_fill_(replace_mask, EOS_ID)

                # Reshape back to the original [1, seq_len] shape
                x = x_seq.unsqueeze(0)

        
        ## Visualize the current denoising step
        cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
        cur_tokens = tokenizer.decode(cur_generated_tokens.tolist())
        ## Display <|endoftext|> tokens as <|delete|> in the visualization
        cur_tokens = cur_tokens.replace("<|endoftext|>", "<|delete|>")
        yield cur_tokens, ''
        time.sleep(visualization_delay)

        #  Expansion Step: Check for expand_id and replace with two mask tokens
        expand_indices = (x[0] == EXPAND_ID).nonzero(as_tuple=False).squeeze(1)
        if expand_indices.numel() > 0:
            # Process from right to left to prevent shifting issues
            for idx in sorted(expand_indices.tolist(), reverse=True):
                x = torch.cat((
                    x[:, :idx],
                    torch.tensor([[MASK_ID, MASK_ID]], device=x.device),
                    x[:, idx + 1:]
                ), dim=1)
                num_generation_tokens += 1
                expand_budget -= 1
                # Truncate back to max_tokens if needed
                if x.shape[1] > max_tokens:
                    x = x[:, :max_tokens]
            cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
            vis_data = []
            # Build per-token (text, color) pairs for a richer display; currently unused
            # because only the decoded string is yielded below
            for j in range(num_generation_tokens):
                current_tok_id = cur_generated_tokens[j].item()
                try:
                    decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
                    display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
                except Exception: display_token = f"[ID:{current_tok_id}]"
                color = None; token_to_display = display_token
                if current_tok_id == MASK_ID: color = "#444444"
                else: color = "#6699CC"

                if token_to_display: vis_data.append((token_to_display, color))
            yield tokenizer.decode(cur_generated_tokens.tolist()), ''
            #yield vis_data
            time.sleep(visualization_delay)
        ## Delete EOS tokens from the middle of the generation window
        
        # Find indices where EOS occurs
        eos_indices = ((x[0] == EOS_ID) & (mask_index[0] == 1)).nonzero(as_tuple=False).squeeze(1)
        if eos_indices.numel() > 0:
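            # Remove each EOS that landed inside the generation window and append a MASK
            # at the end, keeping the total length constant while the window shrinks by one.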
            for idx in sorted(eos_indices.tolist(), reverse=True):
                x = torch.cat((
                    x[:, :idx],
                    x[:, idx + 1:],
                    torch.tensor([[MASK_ID]], device = x.device)
                ), dim = 1)
                num_generation_tokens -= 1

            cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
            yield tokenizer.decode(cur_generated_tokens.tolist()), ''
            time.sleep(visualization_delay)

    generated_code = tokenizer.decode(x[0, prefix_len: prefix_len + num_generation_tokens].tolist())
    yield generated_code, ''

def get_example_input():
    ### This function samples a case from HumanEval-Infilling as prefix and suffix
    task_id = random.choice(list(problems.keys()))
    problem = problems[task_id]
    prefix, suffix = problem['prompt'], problem['suffix']
    test_case = problem['test']
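    # Strip the METADATA block, rename check(candidate) to run_test(), substitute the
    # real entry point for 'candidate', and append a call so the tests run when executed.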
    pattern = r'METADATA\s*=\s*\{.*?\}\n\n'
    test_case = re.sub(pattern, '', test_case, flags=re.DOTALL).strip()
    test_case = test_case.replace('def check(candidate):', 'def run_test():')
    test_case = test_case.replace('candidate', problem['entry_point'])
    test_case = test_case + '\n\nrun_test()'
    return prefix, '', suffix, test_case, task_id, ''

def check_result(prompt, completion, suffix, test_case):
    prompt = str(prompt) if prompt is not None else ""
    completion = str(completion) if completion is not None else ""
    suffix = str(suffix) if suffix is not None else ""
    test_case = str(test_case) if test_case is not None else ""
    print('prefix', prompt)
    print('middle', completion)
    print('suffix', suffix)
    print('test', test_case)
    result = unsafe_execute(prompt, completion, suffix, test_case)
    return result



# --- Gradio UI ---
css = '''
.category-legend{display:none}
'''

def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# DreamOn: Diffusion Language Models For Code Infilling Beyond Fixed-size Canvas\nClick **Example Prompt** to obtain a prefix and suffix, then click **Generate** to create the code.\n\nClick **Run Test Cases** to verify the correctness of the code. You can also input your own test cases.")
        gr.Markdown(
            "[[Model Card](https://huggingface.co/Dream-org/DreamOn-v0-7B)] "
            "[[Blog](https://hkunlp.github.io/blog/2025/dreamon/)]"
            "[[Github](https://github.com/DreamLM/DreamOn)]"
        )

        with gr.Row():
            sample_btn = gr.Button("Example Prompt")
            generate_btn = gr.Button("Generate", variant="primary")
            check_btn = gr.Button("Run Test Cases")
            clear_btn = gr.Button("Clear")

        with gr.Row():
            with gr.Column():
                # Prefix input
                prefix_input = gr.Textbox(
                    label="Prefix Text",
                    placeholder="Enter the beginning of your text...",
                    lines=2
                )
                
                # Middle generation/visualization area
                output_vis = gr.Textbox(
                    label="Generated Text (Middle)", 
                    lines=2
                )
                
                # Suffix input
                suffix_input = gr.Textbox(
                    label="Suffix Text",
                    placeholder="Enter the end of your text...",
                    lines=2
                )
                
                # Hidden Task ID input
                task_id_input = gr.Textbox(
                    label="Task ID",
                    placeholder="Task ID will be stored here...",
                    visible=False
                )
            
            with gr.Column():
                # Test Case input
                test_case_input = gr.Textbox(
                    label="Test Case",
                    placeholder="Enter your test case here...",
                    lines=2
                )
                
                # Result of execution
                result_output = gr.Textbox(
                    label="Result of Execution",
                    placeholder="Execution result will be shown here...",
                    lines=2
                )
        


        # Generation Settings
        with gr.Accordion("Generation Settings"):
            with gr.Row():
                start_gen_len = gr.Slider(
                    minimum=4,
                    maximum=64,
                    value=4,
                    step=4,
                    label="Initial Generation Length"
                )
                max_gen_len = gr.Slider(
                    minimum=32,
                    maximum=64,
                    value=64,
                    step=8,
                    label="Maximum Generation Length"
                )
            with gr.Row():
                expand_budget = gr.Slider(
                    minimum=0,
                    maximum=256,
                    value=64,
                    step=8,
                    label="Expansion Budget"
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.05,
                    label="Temperature"
                )
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-P (0 disables)"
                )
            with gr.Row():
                top_k = gr.Slider(
                    minimum=0, 
                    maximum=200, 
                    value=0, 
                    step=5, 
                    label="Top-K (0 disables)")

            with gr.Row():
                alg = gr.Radio(
                    choices=['maskgit_plus', 'topk_margin', 'entropy'],
                    value='entropy',
                    label="Generation Algorithm"
                )
                alg_temp = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.05,
                    label="Algorithm Temperature"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.1,
                    maximum=3,
                    value=0.2,
                    step=0.1,
                    label="Visualization Delay (s)"
                )
                pad_delete_righthand = gr.Checkbox(
                    label="Delete all tokens on the righthand side of <|delete|>",
                    value=True
                )

        # Connect the UI elements
        generation_inputs = [
            prefix_input,
            suffix_input,
            start_gen_len,
            max_gen_len,
            expand_budget,
            temperature,
            top_p,
            top_k,
            alg,
            alg_temp,
            visualization_delay,
            pad_delete_righthand,
            task_id_input
        ]

        test_inputs=[prefix_input, output_vis, suffix_input, test_case_input]

        generate_btn.click(
            fn=lambda: ("Waiting for ZeroGPU...", ''),
            outputs=[output_vis, result_output],
            show_progress="hidden"
        ).then(
            fn=infilling_dream,
            inputs=generation_inputs,
            outputs=[output_vis, result_output],
            show_progress="hidden"
        )

        clear_btn.click(
            lambda: ("", "", "", "", "", ""),  # Clear all inputs and outputs
            inputs=[],
            outputs=[prefix_input, suffix_input, output_vis, test_case_input, result_output, task_id_input],
            queue=False
        )

        check_btn.click(
            fn=check_result,
            inputs=test_inputs,
            outputs=[result_output],
            queue=False
        )

        sample_btn.click(
            fn=get_example_input,
            outputs=[prefix_input, output_vis, suffix_input, test_case_input, task_id_input, result_output],
            queue=False
        )
    return demo



# --- Launch ---
if __name__ == "__main__":
    #test()
    demo = create_chatbot_demo()
    demo.queue().launch(debug=True)