VisoLearn committed
Commit 8140621 · verified · Parent: f6ac5ae

Update app.py

Files changed (1): app.py (+58 -19)
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import torch
 from threading import Thread
+import time
 
 phi4_model_path = "Intelligent-Internet/II-Medical-8B"
 
@@ -11,10 +12,12 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
 phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
 
+# This is our streaming generator function that yields partial results
 @spaces.GPU(duration=60)
-def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
+def generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
     if not user_message.strip():
-        return history, history
+        yield history, history
+        return
 
     model = phi4_model
     tokenizer = phi4_tokenizer
@@ -60,6 +63,7 @@ Now, analyze the following case:"""
         "streamer": streamer,
     }
 
+    # Start generation in a separate thread
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
@@ -73,10 +77,21 @@ Now, analyze the following case:"""
         assistant_response += cleaned_token
         # Update the last message in history with the current response
         new_history[-1][1] = assistant_response.strip()
-
-    # Return the updated history
-    return new_history, new_history
-
+        yield new_history, new_history
+        # Add a small sleep to control the streaming rate
+        time.sleep(0.01)
+
+    # Return the final state after streaming is completed
+    yield new_history, new_history
+
+# This is our non-streaming wrapper function for buttons that don't support streaming
+def process_input(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
+    generator = generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history)
+    # Get the final result by exhausting the generator
+    result = None
+    for result in generator:
+        pass
+    return result
 
 example_messages = {
     "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
@@ -150,23 +165,47 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
         example3 = gr.Button("Abdominal pain")
         example4 = gr.Button("BMI calculation")
 
-    # Use click instead of stream
+    # Set up the streaming interface
+    def on_submit(message, history, max_tokens, temperature, top_k, top_p, repetition_penalty):
+        # Return the modified history that includes the new user message
+        modified_history = history + [[message, ""]]
+        return "", modified_history, modified_history
+
+    def on_stream(history, max_tokens, temperature, top_k, top_p, repetition_penalty):
+        if not history:
+            return history
+
+        # Get the last user message from history
+        user_message = history[-1][0]
+
+        # Start a fresh history without the last entry
+        prev_history = history[:-1]
+
+        # Generate streaming responses
+        for new_history, _ in generate_streaming_response(
+            user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, prev_history
+        ):
+            yield new_history
+
+    # Connect the submission event
     submit_button.click(
-        fn=generate_response,
-        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider,
-                repetition_penalty_slider, history],
-        outputs=[chatbot, history]
+        fn=on_submit,
+        inputs=[user_input, history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
+        outputs=[user_input, chatbot, history]
     ).then(
-        fn=lambda: gr.update(value=""),
-        inputs=None,
-        outputs=user_input
+        fn=on_stream,
+        inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
+        outputs=chatbot
     )
 
-    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])
-
-    example1.click(lambda: gr.update(value=example_messages["Headache case"]), None, user_input)
-    example2.click(lambda: gr.update(value=example_messages["Chest pain"]), None, user_input)
-    example3.click(lambda: gr.update(value=example_messages["Abdominal pain"]), None, user_input)
-    example4.click(lambda: gr.update(value=example_messages["BMI calculation"]), None, user_input)
+    # Handle examples
+    def set_example(example_text):
+        return gr.update(value=example_text)
+
+    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])
+    example1.click(fn=lambda: set_example(example_messages["Headache case"]), inputs=None, outputs=user_input)
+    example2.click(fn=lambda: set_example(example_messages["Chest pain"]), inputs=None, outputs=user_input)
+    example3.click(fn=lambda: set_example(example_messages["Abdominal pain"]), inputs=None, outputs=user_input)
+    example4.click(fn=lambda: set_example(example_messages["BMI calculation"]), inputs=None, outputs=user_input)
 
 demo.launch(ssr_mode=False)
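For context, the core pattern this commit adopts is Transformers' threaded streaming: model.generate runs in a worker thread while the caller iterates a TextIteratorStreamer as tokens arrive. A minimal self-contained sketch of that pattern follows; the tiny placeholder model "sshleifer/tiny-gpt2" is an assumption for illustration, not the II-Medical-8B model this Space loads.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "sshleifer/tiny-gpt2"  # assumed placeholder; the Space uses II-Medical-8B
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A 35-year-old female presents with", return_tensors="pt")
# skip_prompt=True yields only newly generated text, not the prompt itself
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs in a background thread
# while the main thread consumes partial text from the streamer
generation_kwargs = {**inputs, "max_new_tokens": 40, "streamer": streamer}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

response = ""
for chunk in streamer:  # yields decoded text chunks as they are produced
    response += chunk
thread.join()
print(response)

This is the same streamer/thread/consume loop that generate_streaming_response wraps, with each partial response yielded back to Gradio.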
 
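The new process_input wrapper leans on a standard Python idiom: drain a generator with a for loop and keep only the last value it yielded. A small sketch of the idiom in isolation, with partial_sums standing in for the real generator:

def partial_sums(n):
    # Stand-in for generate_streaming_response: yields a partial result per step
    total = 0
    for i in range(1, n + 1):
        total += i
        yield total

# Exhaust the generator; 'result' ends up holding the final yielded value
result = None
for result in partial_sums(5):
    pass
print(result)  # 15

Note that result stays None if the generator yields nothing, which is why generate_streaming_response yields history even for empty input instead of returning bare.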
 
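The event wiring follows Gradio's usual two-step chat pattern: the click handler appends the user turn and clears the textbox immediately, then a chained .then() streams the assistant turn by yielding successive histories (Gradio streams generator handlers to their outputs). A minimal runnable sketch of the same wiring, using the pair-based history format this app uses and an echo function as a stand-in for real generation:

import time
import gradio as gr

def add_user_turn(message, history):
    # Append the user message with an empty assistant slot; clear the textbox
    return "", history + [[message, ""]]

def stream_bot_turn(history):
    reply = f"Echo: {history[-1][0]}"  # placeholder for real generation
    for i in range(1, len(reply) + 1):
        history[-1][1] = reply[:i]
        time.sleep(0.02)
        yield history  # each yield re-renders the Chatbot with the partial reply

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    user_input = gr.Textbox()
    user_input.submit(
        fn=add_user_turn, inputs=[user_input, chatbot], outputs=[user_input, chatbot]
    ).then(
        fn=stream_bot_turn, inputs=chatbot, outputs=chatbot
    )

demo.launch()

Splitting submission into two chained steps is what lets the user's message appear instantly while the generated reply streams in afterwards.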