root commited on
Commit
01f9743
·
1 Parent(s): 3b71bd2

update testcase

Browse files
Files changed (2) hide show
  1. app.py +52 -61
  2. eval.py +201 -0
app.py CHANGED
@@ -13,7 +13,7 @@ import traceback # For better error printing
13
  import random
14
  import gzip
15
  import json
16
-
17
 
18
  def read_problems() -> Dict[str, Dict]:
19
  benchmark_file = "HumanEval-SingleLineInfilling.jsonl.gz"
@@ -223,7 +223,8 @@ def infilling_dream(
223
  alg: str,
224
  alg_temp: Optional[float],
225
  visualization_delay: float,
226
- delete_righthand_eos: bool
 
227
  ) -> List[Tuple[str, str]]:
228
  # ------1. Prepare the input for infilling -----------------
229
  prefix = prefix
@@ -249,7 +250,7 @@ def infilling_dream(
249
 
250
  previous_tokens_vis = initial_generated_tokens
251
  #yield vis_data_initial
252
- yield tokenizer.decode(initial_generated_tokens.tolist())
253
  time.sleep(visualization_delay)
254
 
255
  # ----2. Step by Step Infilling ----------------------------------------
@@ -357,7 +358,7 @@ def infilling_dream(
357
  cur_tokens = tokenizer.decode(cur_generated_tokens.tolist())
358
  ## replace all <|endoftext|> with <|delete|>
359
  cur_tokens = cur_tokens.replace("<|endoftext|>", "<|delete|>")
360
- yield cur_tokens
361
  time.sleep(visualization_delay)
362
 
363
  # Expansion Step: Check for expand_id and replace with two mask tokens
@@ -389,7 +390,7 @@ def infilling_dream(
389
  else: color = "#6699CC"
390
 
391
  if token_to_display: vis_data.append((token_to_display, color))
392
- yield tokenizer.decode(cur_generated_tokens.tolist())
393
  #yield vis_data
394
  time.sleep(visualization_delay)
395
  ## detele EOS tokens from middle
@@ -418,23 +419,36 @@ def infilling_dream(
418
  # else: color = "#6699CC"
419
 
420
  # vis_data.append((token_to_display, color))
421
- yield tokenizer.decode(cur_generated_tokens.tolist())
422
  #yield vis_data
423
  time.sleep(visualization_delay)
424
 
425
- yield tokenizer.decode(x[0, prefix_len: prefix_len + num_generation_tokens].tolist())
 
 
 
426
  def get_example_input():
427
  ### this functions samples a case from humaneval-infilling as prefix and suffix
428
  task_id = random.choice(list(problems.keys()))
429
- example = problems[task_id]
430
- prefix, suffix = example['prompt'], example['suffix']
431
- return prefix, suffix, ''
 
 
 
 
 
 
 
 
 
432
 
433
 
434
  # --- Gradio UI ---
435
  css = '''
436
  .category-legend{display:none}
437
  '''
 
438
  def create_chatbot_demo():
439
  with gr.Blocks(css=css) as demo:
440
  gr.Markdown("# DreamOn: Diffusion Language Models For Code Infilling Beyond Fixed-size Canvas\nClick **Example Prompt** to get a prefix and suffix, then click **Generate** to generate code. Have fun!")
@@ -464,6 +478,27 @@ def create_chatbot_demo():
464
  lines=2
465
  )
466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  with gr.Row():
468
  sample_btn = gr.Button("Example Prompt")
469
 
@@ -512,12 +547,6 @@ def create_chatbot_demo():
512
  label="Top-P (0 disables)"
513
  )
514
  with gr.Row():
515
- top_p = gr.Slider(
516
- minimum=0.0,
517
- maximum=1.0,
518
- value=0.95,
519
- step=0.05,
520
- label="Top-P (0 disables)")
521
  top_k = gr.Slider(
522
  minimum=0,
523
  maximum=200,
@@ -525,7 +554,6 @@ def create_chatbot_demo():
525
  step=5,
526
  label="Top-K (0 disables)")
527
 
528
-
529
  with gr.Row():
530
  alg = gr.Radio(
531
  choices=['maskgit_plus', 'topk_margin', 'entropy'],
@@ -552,7 +580,6 @@ def create_chatbot_demo():
552
  value=True
553
  )
554
 
555
-
556
  # Connect the UI elements
557
  generation_inputs = [
558
  prefix_input,
@@ -566,71 +593,35 @@ def create_chatbot_demo():
566
  alg,
567
  alg_temp,
568
  visualization_delay,
569
- pad_delete_righthand
 
570
  ]
571
 
572
  generate_btn.click(
573
  fn=infilling_dream,
574
  inputs=generation_inputs,
575
- outputs=[output_vis],
576
  show_progress="hidden"
577
  )
578
 
579
  clear_btn.click(
580
- lambda: ("", "", ""), # Clear all inputs and outputs
581
  inputs=[],
582
- outputs=[prefix_input, suffix_input, output_vis],
583
  queue=False
584
  )
585
 
586
  sample_btn.click(
587
  fn=get_example_input,
588
- outputs=[prefix_input, suffix_input, output_vis],
589
  queue=False
590
  )
591
  return demo
592
 
593
- def test():
594
- prefix = '''import List
595
 
596
- def has_close_elements(numbers: List[float], threshold: float) -> bool:
597
- """ Check if in given list of numbers, are any two numbers closer to each other than
598
- given threshold.
599
- >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
600
- False
601
- >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
602
- True
603
- """
604
-
605
- '''
606
-
607
- suffix = '''
608
- for idx2, elem2 in enumerate(numbers):
609
- if idx != idx2:
610
- distance = abs(elem - elem2)
611
- if distance < threshold:
612
- return True
613
-
614
- return False
615
- '''
616
- infilling_dream(
617
- prefix=prefix,
618
- suffix= suffix,
619
- start_gen_len=4,
620
- max_gen_len= 64,
621
- expand_budget= 60,
622
- temperature=0.2,
623
- top_p = 0.9,
624
- top_k = None,
625
- alg= 'entropy',
626
- alg_temp=0,
627
- visualization_delay= 0,
628
- delete_righthand_eos= True
629
- )
630
 
631
  # --- Launch ---
632
  if __name__ == "__main__":
633
  #test()
634
  demo = create_chatbot_demo()
635
- demo.queue().launch(debug=True, share=True)
636
-
 
13
  import random
14
  import gzip
15
  import json
16
+ from eval import unsafe_execute ## args:(problem, completion)
17
 
18
  def read_problems() -> Dict[str, Dict]:
19
  benchmark_file = "HumanEval-SingleLineInfilling.jsonl.gz"
 
223
  alg: str,
224
  alg_temp: Optional[float],
225
  visualization_delay: float,
226
+ delete_righthand_eos: bool,
227
+ task_id: str
228
  ) -> List[Tuple[str, str]]:
229
  # ------1. Prepare the input for infilling -----------------
230
  prefix = prefix
 
250
 
251
  previous_tokens_vis = initial_generated_tokens
252
  #yield vis_data_initial
253
+ yield tokenizer.decode(initial_generated_tokens.tolist()), ''
254
  time.sleep(visualization_delay)
255
 
256
  # ----2. Step by Step Infilling ----------------------------------------
 
358
  cur_tokens = tokenizer.decode(cur_generated_tokens.tolist())
359
  ## replace all <|endoftext|> with <|delete|>
360
  cur_tokens = cur_tokens.replace("<|endoftext|>", "<|delete|>")
361
+ yield cur_tokens, ''
362
  time.sleep(visualization_delay)
363
 
364
  # Expansion Step: Check for expand_id and replace with two mask tokens
 
390
  else: color = "#6699CC"
391
 
392
  if token_to_display: vis_data.append((token_to_display, color))
393
+ yield tokenizer.decode(cur_generated_tokens.tolist()), ''
394
  #yield vis_data
395
  time.sleep(visualization_delay)
396
  ## detele EOS tokens from middle
 
419
  # else: color = "#6699CC"
420
 
421
  # vis_data.append((token_to_display, color))
422
+ yield tokenizer.decode(cur_generated_tokens.tolist()), ''
423
  #yield vis_data
424
  time.sleep(visualization_delay)
425
 
426
+ generated_code = tokenizer.decode(x[0, prefix_len: prefix_len + num_generation_tokens].tolist())
427
+ yield generated_code, ''
428
+ result = check_result(task_id, generated_code)
429
+ yield generated_code, result
430
  def get_example_input():
431
  ### this functions samples a case from humaneval-infilling as prefix and suffix
432
  task_id = random.choice(list(problems.keys()))
433
+ problem = problems[task_id]
434
+ prefix, suffix = problem['prompt'], problem['suffix']
435
+ test_case = problem['test']
436
+ return prefix, suffix, test_case, task_id
437
+
438
+ def check_result(task_id, completion):
439
+ # 从数据集中读取问题
440
+ problem = problems[task_id]
441
+ # 这里假设 `unsafe_execute` 是一个可以执行代码并返回结果的函数
442
+ result = unsafe_execute(problem, completion)
443
+ return result
444
+
445
 
446
 
447
  # --- Gradio UI ---
448
  css = '''
449
  .category-legend{display:none}
450
  '''
451
+
452
  def create_chatbot_demo():
453
  with gr.Blocks(css=css) as demo:
454
  gr.Markdown("# DreamOn: Diffusion Language Models For Code Infilling Beyond Fixed-size Canvas\nClick **Example Prompt** to get a prefix and suffix, then click **Generate** to generate code. Have fun!")
 
478
  lines=2
479
  )
480
 
481
+ # Test Case input
482
+ test_case_input = gr.Textbox(
483
+ label="Test Case",
484
+ placeholder="Enter your test case here...",
485
+ lines=2
486
+ )
487
+
488
+ # Hidden Task ID input
489
+ task_id_input = gr.Textbox(
490
+ label="Task ID",
491
+ placeholder="Task ID will be stored here...",
492
+ visible=False
493
+ )
494
+
495
+ # Result of execution
496
+ result_output = gr.Textbox(
497
+ label="Result of Execution",
498
+ placeholder="Execution result will be shown here...",
499
+ lines=2
500
+ )
501
+
502
  with gr.Row():
503
  sample_btn = gr.Button("Example Prompt")
504
 
 
547
  label="Top-P (0 disables)"
548
  )
549
  with gr.Row():
 
 
 
 
 
 
550
  top_k = gr.Slider(
551
  minimum=0,
552
  maximum=200,
 
554
  step=5,
555
  label="Top-K (0 disables)")
556
 
 
557
  with gr.Row():
558
  alg = gr.Radio(
559
  choices=['maskgit_plus', 'topk_margin', 'entropy'],
 
580
  value=True
581
  )
582
 
 
583
  # Connect the UI elements
584
  generation_inputs = [
585
  prefix_input,
 
593
  alg,
594
  alg_temp,
595
  visualization_delay,
596
+ pad_delete_righthand,
597
+ task_id_input
598
  ]
599
 
600
  generate_btn.click(
601
  fn=infilling_dream,
602
  inputs=generation_inputs,
603
+ outputs=[output_vis, result_output],
604
  show_progress="hidden"
605
  )
606
 
607
  clear_btn.click(
608
+ lambda: ("", "", "", "", "", ""), # Clear all inputs and outputs
609
  inputs=[],
610
+ outputs=[prefix_input, suffix_input, output_vis, test_case_input, result_output, task_id_input],
611
  queue=False
612
  )
613
 
614
  sample_btn.click(
615
  fn=get_example_input,
616
+ outputs=[prefix_input, suffix_input, test_case_input, task_id_input],
617
  queue=False
618
  )
619
  return demo
620
 
 
 
621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
  # --- Launch ---
624
  if __name__ == "__main__":
625
  #test()
626
  demo = create_chatbot_demo()
627
+ demo.queue().launch(debug=True, share=True)
 
eval.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import contextlib
3
+ import faulthandler
4
+ import io
5
+ import multiprocessing
6
+ import os
7
+ import platform
8
+ import signal
9
+ import tempfile
10
+ from typing import Callable, Dict, Optional
11
+
12
+
13
+ def unsafe_execute(problem, completion, timeout=3):
14
+ with create_tempdir():
15
+ # Construct the check program
16
+ check_program = (
17
+ problem["prompt"]
18
+ + completion
19
+ + problem["suffix"]
20
+ + "\n"
21
+ + problem["test"]
22
+ + "\n"
23
+ + f"check({problem['entry_point']})"
24
+ )
25
+
26
+ # Use multiprocessing to execute the code in a separate process
27
+ result_queue = multiprocessing.Queue()
28
+
29
+ def worker(check_program, result_queue):
30
+ try:
31
+ exec_globals = {}
32
+ with swallow_io():
33
+ with time_limit(timeout):
34
+ exec(check_program, exec_globals)
35
+ result_queue.put("passed")
36
+ except TimeoutException:
37
+ result_queue.put("timed out")
38
+ except BaseException as e:
39
+ result_queue.put(f"failed: {e}")
40
+
41
+ process = multiprocessing.Process(target=worker, args=(check_program, result_queue))
42
+ process.start()
43
+ process.join(timeout + 1) # Give some extra time for cleanup
44
+
45
+ if process.is_alive():
46
+ process.terminate()
47
+ result = "timed out"
48
+ else:
49
+ result = result_queue.get()
50
+
51
+ return result
52
+
53
+
54
+ @contextlib.contextmanager
55
+ def time_limit(seconds: float):
56
+ def signal_handler(signum, frame):
57
+ raise TimeoutException("Timed out!")
58
+
59
+ signal.setitimer(signal.ITIMER_REAL, seconds)
60
+ signal.signal(signal.SIGALRM, signal_handler)
61
+ try:
62
+ yield
63
+ finally:
64
+ signal.setitimer(signal.ITIMER_REAL, 0)
65
+
66
+
67
+ @contextlib.contextmanager
68
+ def swallow_io():
69
+ stream = WriteOnlyStringIO()
70
+ with contextlib.redirect_stdout(stream):
71
+ with contextlib.redirect_stderr(stream):
72
+ with redirect_stdin(stream):
73
+ yield
74
+
75
+
76
+ @contextlib.contextmanager
77
+ def create_tempdir():
78
+ with tempfile.TemporaryDirectory() as dirname:
79
+ with chdir(dirname):
80
+ yield dirname
81
+
82
+
83
+ class TimeoutException(Exception):
84
+ pass
85
+
86
+
87
+ class WriteOnlyStringIO(io.StringIO):
88
+ """StringIO that throws an exception when it's read from"""
89
+
90
+ def read(self, *args, **kwargs):
91
+ raise IOError
92
+
93
+ def readline(self, *args, **kwargs):
94
+ raise IOError
95
+
96
+ def readlines(self, *args, **kwargs):
97
+ raise IOError
98
+
99
+ def readable(self, *args, **kwargs):
100
+ """Returns True if the IO object can be read."""
101
+ return False
102
+
103
+
104
+ class redirect_stdin(contextlib._RedirectStream): # type: ignore
105
+ _stream = "stdin"
106
+
107
+
108
+ @contextlib.contextmanager
109
+ def chdir(root):
110
+ if root == ".":
111
+ yield
112
+ return
113
+ cwd = os.getcwd()
114
+ os.chdir(root)
115
+ try:
116
+ yield
117
+ except BaseException as exc:
118
+ raise exc
119
+ finally:
120
+ os.chdir(cwd)
121
+
122
+
123
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
124
+ """
125
+ This disables various destructive functions and prevents the generated code
126
+ from interfering with the test (e.g. fork bomb, killing other processes,
127
+ removing filesystem files, etc.)
128
+
129
+ WARNING
130
+ This function is NOT a security sandbox. Untrusted code, including, model-
131
+ generated code, should not be blindly executed outside of one. See the
132
+ Codex paper for more information about OpenAI's code sandbox, and proceed
133
+ with caution.
134
+ """
135
+
136
+ if maximum_memory_bytes is not None:
137
+ import resource
138
+
139
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
140
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
141
+ if not platform.uname().system == "Darwin":
142
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
143
+
144
+ faulthandler.disable()
145
+
146
+ import builtins
147
+
148
+ builtins.exit = None
149
+ builtins.quit = None
150
+
151
+ import os
152
+
153
+ os.environ["OMP_NUM_THREADS"] = "1"
154
+
155
+ os.kill = None
156
+ os.system = None
157
+ os.putenv = None
158
+ os.remove = None
159
+ os.removedirs = None
160
+ os.rmdir = None
161
+ os.fchdir = None
162
+ os.setuid = None
163
+ os.fork = None
164
+ os.forkpty = None
165
+ os.killpg = None
166
+ os.rename = None
167
+ os.renames = None
168
+ os.truncate = None
169
+ os.replace = None
170
+ os.unlink = None
171
+ os.fchmod = None
172
+ os.fchown = None
173
+ os.chmod = None
174
+ os.chown = None
175
+ os.chroot = None
176
+ os.fchdir = None
177
+ os.lchflags = None
178
+ os.lchmod = None
179
+ os.lchown = None
180
+ os.getcwd = None
181
+ os.chdir = None
182
+
183
+ import shutil
184
+
185
+ shutil.rmtree = None
186
+ shutil.move = None
187
+ shutil.chown = None
188
+
189
+ import subprocess
190
+
191
+ subprocess.Popen = None # type: ignore
192
+
193
+ __builtins__["help"] = None
194
+
195
+ import sys
196
+
197
+ sys.modules["ipdb"] = None
198
+ sys.modules["joblib"] = None
199
+ sys.modules["resource"] = None
200
+ sys.modules["psutil"] = None
201
+ sys.modules["tkinter"] = None