oliver-aizip committed
Commit 665e5a3 · Parent: 8151596

first pass at async handling

Files changed (2):
  1. app.py +55 -3
  2. utils/models.py +6 -4
app.py CHANGED
@@ -41,7 +41,7 @@ def load_context():
         show_full
     ]
 
-def generate_model_summaries_with_timeout(example, timeout=30):
+def generate_model_summaries_with_timeout(example, timeout=60):
     """Run model inference in a separate thread with timeout for interruptibility"""
     import threading
     import time
@@ -75,6 +75,7 @@ def generate_model_summaries_with_timeout(example, timeout=30):
     generation_thread.daemon = True
     generation_thread.start()
 
+    # Wait for the generation thread to finish, or bail out on timeout/interrupt
     start_time = time.time()
     while time.time() - start_time < timeout:
         if generation_interrupt.is_set() or not generation_thread.is_alive() or result["completed"]:
@@ -83,6 +84,50 @@ def generate_model_summaries_with_timeout(example, timeout=30):
 
     return result
 
+async def generate_model_summaries_with_timeout_async(example, timeout=30):
+    """Async version that properly waits for the thread"""
+    import asyncio
+    import threading
+    import time
+
+    result = {
+        "model_a": "",
+        "model_b": "",
+        "summary_a": "",
+        "summary_b": "",
+        "completed": False
+    }
+
+    if generation_interrupt.is_set():
+        return result
+
+    def run_generation():
+        try:
+            m_a_name, m_b_name = random.sample(model_names, 2)
+            s_a, s_b = generate_summaries(example, m_a_name, m_b_name)
+
+            if not generation_interrupt.is_set():
+                result["model_a"] = m_a_name
+                result["model_b"] = m_b_name
+                result["summary_a"] = s_a
+                result["summary_b"] = s_b
+                result["completed"] = True
+        except Exception as e:
+            print(f"Error in generation thread: {e}")
+
+    generation_thread = threading.Thread(target=run_generation)
+    generation_thread.daemon = True
+    generation_thread.start()
+
+    # Use asyncio.sleep instead of time.sleep for async waiting
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        if generation_interrupt.is_set() or not generation_thread.is_alive() or result["completed"]:
+            break
+        await asyncio.sleep(0.1)  # Non-blocking sleep
+
+    return result
+
 def process_generation_result(result):
     """Process the results from the threaded generation function"""
     if not result["completed"]:
@@ -122,6 +167,13 @@ def process_generation_result(result):
         gr.update(interactive=True),
         gr.update(elem_classes=[])
     ]
+async def process_example_async(example):
+    result = await generate_model_summaries_with_timeout_async(example)
+    return process_generation_result(result)
+
+def process_example_sync(example):
+    result = generate_model_summaries_with_timeout(example)
+    return process_generation_result(result)
 
 def select_vote_improved(winner_choice):
     """Updates UI based on vote selection"""
@@ -346,7 +398,7 @@ with gr.Blocks(theme=gr.themes.Default(
         outputs=[current_example, query_display, context_description, context_display,
                  context_toggle_btn, show_full_context]
     ).then(
-        fn=lambda example: process_generation_result(generate_model_summaries_with_timeout(example)),
+        fn=process_example_async,
         inputs=[current_example],
         outputs=[model_a_name, model_b_name, summary_a_text, summary_b_text,
                  selected_winner, feedback_list, show_results_state, results_agg,
@@ -367,7 +419,7 @@ with gr.Blocks(theme=gr.themes.Default(
         outputs=[query_display, context_description, context_display,
                  context_toggle_btn, show_full_context]
     ).then(
-        fn=lambda example: process_generation_result(generate_model_summaries_with_timeout(example)),
+        fn=process_example_sync,
        inputs=[current_example],
        outputs=[model_a_name, model_b_name, summary_a_text, summary_b_text,
                 selected_winner, feedback_list, show_results_state, results_agg,
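Gradio accepts `async def` functions as event handlers and awaits them on its own event loop, which is why `process_example_async` can be passed straight to `.then()` above. The sketch below is illustrative rather than part of this commit: `slow_generate`, `handle_click`, and the component names are invented, and it uses `asyncio.to_thread` plus `asyncio.wait_for` as one possible alternative to the manual thread-polling loop shown in the diff.

```python
# Illustrative sketch only, not from this commit.
# General pattern: an async Gradio handler offloads a blocking model call
# to a worker thread and enforces a timeout without freezing the UI.
import asyncio
import gradio as gr

def slow_generate(prompt: str) -> str:      # hypothetical blocking inference call
    import time
    time.sleep(2)                           # stand-in for model.generate(...)
    return f"summary for: {prompt}"

async def handle_click(prompt: str) -> str:  # Gradio awaits async handlers natively
    try:
        # Run the blocking call in a thread and give up after 30 seconds,
        # similar in spirit to generate_model_summaries_with_timeout_async.
        return await asyncio.wait_for(asyncio.to_thread(slow_generate, prompt), timeout=30)
    except asyncio.TimeoutError:
        return "generation timed out"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Query")
    out = gr.Textbox(label="Summary")
    gr.Button("Run").click(fn=handle_click, inputs=box, outputs=out)

if __name__ == "__main__":
    demo.launch()
```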
utils/models.py CHANGED
@@ -14,7 +14,10 @@ from .prompts import format_rag_prompt
 models = {
     "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
     "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct", # remove gated for now
-    #"Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
+    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
+    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
+    "Gemma-3-1b-it": "google/gemma-3-1b-it",
+    #"Bitnet-b1.58-2B-4T": "microsoft/bitnet-b1.58-2B-4T",
     #TODO add more models
 }
 
@@ -47,7 +50,6 @@ def generate_summaries(example, model_a_name, model_b_name):
     summary_b = run_inference(models[model_b_name], context_text, question)
     return summary_a, summary_b
 
-
 def run_inference(model_name, context, question):
     """
     Run inference using the specified model.
@@ -55,7 +57,7 @@ def run_inference(model_name, context, question):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Load the model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
     accepts_sys = (
         "System role not supported" not in tokenizer.chat_template
     ) # Workaround for Gemma
@@ -65,7 +67,7 @@ def run_inference(model_name, context, question):
     tokenizer.pad_token = tokenizer.eos_token
 
     model = AutoModelForCausalLM.from_pretrained(
-        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager"
+        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
     ).to(device)
 
     text_input = format_rag_prompt(question, context, accepts_sys)
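The `token=True` arguments go hand in hand with enabling the Llama 3.2 and Gemma 3 checkpoints, which are gated on the Hugging Face Hub: passing `token=True` to `from_pretrained` makes transformers use the locally stored access token instead of making an anonymous request. Below is a minimal sketch of how that credential is typically provisioned; it is not part of this commit, and the `HF_TOKEN` environment variable is the standard Hub convention rather than anything defined in this repo.

```python
# Illustrative sketch only, not from this commit.
# One common way to make token=True work for gated models such as
# meta-llama/llama-3.2-1b-instruct or google/gemma-3-1b-it.
import os
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Store the Hub credential locally; assumes HF_TOKEN holds a valid access token.
login(token=os.environ["HF_TOKEN"])

model_id = "meta-llama/llama-3.2-1b-instruct"

# With the credential cached, token=True resolves to it automatically.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left", token=True)
model = AutoModelForCausalLM.from_pretrained(model_id, token=True)
```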