VisoLearn committed (verified)
Commit afa4eb5
1 Parent(s): fa6be8b

Update app.py

Files changed (1)
  1. app.py +129 -75
app.py CHANGED
@@ -1,131 +1,185 @@
- import spaces
  import gradio as gr
- from transformers import AutoTokenizer, TextIteratorStreamer
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
  import torch
  from threading import Thread

- # Model and device configuration
- phi4_model_path = "Compumacy/OpenBioLLm-70B"
- # Specify the base filename of the GPTQ checkpoint in the repo
- model_basename = "gptq_model-2bit-128g.safetensors"
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # === GPTQ 2-bit QUANTIZATION CONFIG ===
- quantize_config = BaseQuantizeConfig(
-     bits=2,          # 2-bit quantization
-     group_size=128,  # grouping size
-     desc_act=False   # disable descending activations
- )
-
- # === LOAD GPTQ-QUANTIZED MODEL ===
- model = AutoGPTQForCausalLM.from_quantized(
-     phi4_model_path,
-     model_basename=model_basename,
-     quantize_config=quantize_config,
-     device_map="auto",
-     use_safetensors=True,
- )
-
- tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
-
- # === OPTIONAL: TorchCompile for optimization (PyTorch >= 2.0) ===
- try:
-     model = torch.compile(model)
- except Exception:
-     pass
-
- # === STREAMING RESPONSE GENERATOR ===
- @spaces.GPU()
  def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
      if not user_message.strip():
          return history_state, history_state

-     system_message = (
-         "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
-     )
-     start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"

-     # Build prompt
      prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
-     for msg in history_state:
-         prompt += f"{start_tag}{msg['role']}{sep_tag}{msg['content']}{end_tag}"
      prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

      inputs = tokenizer(prompt, return_tensors="pt").to(device)
      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
      generation_kwargs = {
-         "input_ids": inputs.input_ids,
-         "attention_mask": inputs.attention_mask,
          "max_new_tokens": int(max_tokens),
          "do_sample": True,
-         "temperature": temperature,
          "top_k": int(top_k),
-         "top_p": top_p,
-         "repetition_penalty": repetition_penalty,
-         "streamer": streamer
      }

-     Thread(target=model.generate, kwargs=generation_kwargs).start()

      assistant_response = ""
      new_history = history_state + [
          {"role": "user", "content": user_message},
          {"role": "assistant", "content": ""}
      ]
-
-     for token in streamer:
-         clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
-         assistant_response += clean
-         new_history[-1]["content"] = assistant_response
          yield new_history, new_history

      yield new_history, new_history

- # === EXAMPLES ===
  example_messages = {
-     "Math reasoning": "If a rectangular prism has a length of 6 cm...",
-     "Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
-     "Physics problem": "A ball is thrown upward with an initial velocity..."
  }

- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("""
-     # Phi-4 Chat with GPTQ Quant
-     Try the example problems below to see how the model breaks down complex reasoning.
-     """)

      history_state = gr.State([])
      with gr.Row():
          with gr.Column(scale=1):
              gr.Markdown("### Settings")
-             max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens")
              with gr.Accordion("Advanced Settings", open=False):
                  temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                  top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                  top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                  repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
          with gr.Column(scale=4):
-             chatbot = gr.Chatbot(label="Chat", type="messages")
              with gr.Row():
-                 user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                  submit_button = gr.Button("Send", variant="primary", scale=1)
                  clear_button = gr.Button("Clear", scale=1)
-             gr.Markdown("**Try these examples:**")
              with gr.Row():
-                 for name, text in example_messages.items():
-                     btn = gr.Button(name)
-                     btn.click(fn=lambda t=text: gr.update(value=t), inputs=None, outputs=user_input)

      submit_button.click(
          fn=generate_response,
          inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
          outputs=[chatbot, history_state]
-     ).then(lambda: gr.update(value=""), None, user_input)

-     clear_button.click(lambda: ([], []), None, [chatbot, history_state])

  demo.launch(ssr_mode=False)
-
- # If you still see missing CUDA kernels warnings, reinstall AutoGPTQ with CUDA support:
- # pip install git+https://github.com/PanQiWei/AutoGPTQ.git#egg=auto-gptq[cuda]
 
  import gradio as gr
+ import spaces
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import torch
  from threading import Thread
+ import re
+
+ phi4_model_path = "Intelligent-Internet/II-Medical-8B"
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"

+ phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
+ phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
+
+ @spaces.GPU(duration=60)
  def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
      if not user_message.strip():
          return history_state, history_state
+
+     model = phi4_model
+     tokenizer = phi4_tokenizer
+     start_tag = "<|im_start|>"
+     sep_tag = "<|im_sep|>"
+     end_tag = "<|im_end|>"

+     system_message = """You are a highly knowledgeable and thoughtful AI medical assistant. Your primary role is to assist with diagnostic reasoning by evaluating patient symptoms, medical history, and relevant clinical context.
+
+ Structure your response into two main sections using the following format: <think> {Thought section} </think> {Solution section}.

+ In the <think> section, use structured clinical reasoning to:
+ - Identify possible differential diagnoses based on the given symptoms.
+ - Consider risk factors, medical history, duration, and severity of symptoms.
+ - Use step-by-step logic to rule in or rule out conditions.
+ - Reflect on diagnostic uncertainty and suggest further assessments if needed.
+
+ In the <solution> section, provide your most likely diagnosis or clinical assessment along with the rationale. Include brief suggestions for potential next steps like labs, imaging, or referrals if appropriate.
+
+ IMPORTANT: When referencing lab values or pathophysiological mechanisms, use LaTeX formatting for clarity. Use $...$ for inline and $$...$$ for block-level expressions.
+
+ Now, please analyze and respond to the following case:
+ """
+
      prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
+     for message in history_state:
+         if message["role"] == "user":
+             prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
+         elif message["role"] == "assistant" and message["content"]:
+             prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
      prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

      inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+
      generation_kwargs = {
+         "input_ids": inputs["input_ids"],
+         "attention_mask": inputs["attention_mask"],
          "max_new_tokens": int(max_tokens),
          "do_sample": True,
+         "temperature": float(temperature),
          "top_k": int(top_k),
+         "top_p": float(top_p),
+         "repetition_penalty": float(repetition_penalty),
+         "streamer": streamer,
      }

+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()

      assistant_response = ""
      new_history = history_state + [
          {"role": "user", "content": user_message},
          {"role": "assistant", "content": ""}
      ]
+
+     for new_token in streamer:
+         cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
+         assistant_response += cleaned_token
+         new_history[-1]["content"] = assistant_response.strip()
          yield new_history, new_history

      yield new_history, new_history

+ # Updated example cases for medical diagnostics
  example_messages = {
+     "Chest Pain": "A 58-year-old man presents with chest pain that started 20 minutes ago while climbing stairs. He describes it as a heavy pressure in the center of his chest, radiating to his left arm. He has a history of hypertension and smoking. What is the likely diagnosis?",
+     "Shortness of Breath": "A 34-year-old woman presents with 3 days of worsening shortness of breath, low-grade fever, and a dry cough. She denies chest pain or recent travel. Pulse oximetry is 91% on room air.",
+     "Abdominal Pain": "A 22-year-old female presents with lower right quadrant abdominal pain, nausea, and fever. The pain started around the umbilicus and migrated to the right lower quadrant over the past 12 hours.",
+     "Pediatric Fever": "A 2-year-old child has a fever of 39.5°C, irritability, and a rash on the trunk and arms. The child received all standard vaccinations and has no sick contacts. What should be considered in the differential diagnosis?"
  }

+ # Custom CSS
+ css = """
+ .markdown-body .katex {
+     font-size: 1.2em;
+ }
+ .markdown-body .katex-display {
+     margin: 1em 0;
+     overflow-x: auto;
+     overflow-y: hidden;
+ }
+ """
+
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+     gr.Markdown(
+         """
+         # Medical Diagnosis Assistant
+         This AI assistant uses structured reasoning to evaluate clinical cases and assist with diagnostic decision-making. Includes LaTeX support for medical calculations.
+         """
+     )
+
+     gr.HTML("""
+     <script>
+     if (typeof window.MathJax === 'undefined') {
+         const script = document.createElement('script');
+         script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
+         script.async = true;
+         document.head.appendChild(script);
+         window.MathJax = {
+             tex2jax: {
+                 inlineMath: [['$', '$']],
+                 displayMath: [['$$', '$$']],
+                 processEscapes: true
+             },
+             showProcessingMessages: false,
+             messageStyle: 'none'
+         };
+     }
+
+     function rerender() {
+         if (window.MathJax && window.MathJax.Hub) {
+             window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
+         }
+     }
+     setInterval(rerender, 1000);
+     </script>
+     """)

      history_state = gr.State([])
+
      with gr.Row():
          with gr.Column(scale=1):
              gr.Markdown("### Settings")
+             max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
              with gr.Accordion("Advanced Settings", open=False):
                  temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                  top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                  top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                  repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
+
          with gr.Column(scale=4):
+             chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", elem_id="chatbot", show_copy_button=True)
              with gr.Row():
+                 user_input = gr.Textbox(label="Describe patient symptoms...", placeholder="Type a clinical case here...", scale=3)
                  submit_button = gr.Button("Send", variant="primary", scale=1)
                  clear_button = gr.Button("Clear", scale=1)
+             gr.Markdown("**Try these example cases:**")
              with gr.Row():
+                 example1_button = gr.Button("Chest Pain")
+                 example2_button = gr.Button("Shortness of Breath")
+                 example3_button = gr.Button("Abdominal Pain")
+                 example4_button = gr.Button("Pediatric Fever")

      submit_button.click(
          fn=generate_response,
          inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
          outputs=[chatbot, history_state]
+     ).then(
+         fn=lambda: gr.update(value=""),
+         inputs=None,
+         outputs=user_input
+     )
+
+     clear_button.click(
+         fn=lambda: ([], []),
+         inputs=None,
+         outputs=[chatbot, history_state]
+     )

+     example1_button.click(fn=lambda: gr.update(value=example_messages["Chest Pain"]), inputs=None, outputs=user_input)
+     example2_button.click(fn=lambda: gr.update(value=example_messages["Shortness of Breath"]), inputs=None, outputs=user_input)
+     example3_button.click(fn=lambda: gr.update(value=example_messages["Abdominal Pain"]), inputs=None, outputs=user_input)
+     example4_button.click(fn=lambda: gr.update(value=example_messages["Pediatric Fever"]), inputs=None, outputs=user_input)

  demo.launch(ssr_mode=False)