Spaces: Running on Zero
Update app.py
app.py
CHANGED
@@ -1,131 +1,185 @@
-import spaces
 import gradio as gr
-
-from
 import torch
 from threading import Thread
 
-
-
-
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# === GPTQ 2-bit QUANTIZATION CONFIG ===
-quantize_config = BaseQuantizeConfig(
-    bits=2,          # 2-bit quantization
-    group_size=128,  # grouping size
-    desc_act=False   # disable descending activations
-)
-
-# === LOAD GPTQ-QUANTIZED MODEL ===
-model = AutoGPTQForCausalLM.from_quantized(
-    phi4_model_path,
-    model_basename=model_basename,
-    quantize_config=quantize_config,
-    device_map="auto",
-    use_safetensors=True,
-)
-
-tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
-
-# === OPTIONAL: TorchCompile for optimization (PyTorch >= 2.0) ===
-try:
-    model = torch.compile(model)
-except Exception:
-    pass
-
-# === STREAMING RESPONSE GENERATOR ===
-@spaces.GPU()
 def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
     if not user_message.strip():
         return history_state, history_state
 
-    system_message =
-
-
-    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
 
-
     prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
-    for
-
     prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
 
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
     generation_kwargs = {
-        "input_ids": inputs
-        "attention_mask": inputs
         "max_new_tokens": int(max_tokens),
         "do_sample": True,
-        "temperature": temperature,
         "top_k": int(top_k),
-        "top_p": top_p,
-        "repetition_penalty": repetition_penalty,
-        "streamer": streamer
     }
 
-    Thread(target=model.generate, kwargs=generation_kwargs)
 
     assistant_response = ""
     new_history = history_state + [
         {"role": "user", "content": user_message},
         {"role": "assistant", "content": ""}
     ]
-
-    for
-
-        assistant_response +=
-        new_history[-1]["content"] = assistant_response
         yield new_history, new_history
 
     yield new_history, new_history
 
-#
 example_messages = {
-    "
-    "
-    "
 }
 
-
-
-
-
-
 
     history_state = gr.State([])
     with gr.Row():
         with gr.Column(scale=1):
             gr.Markdown("### Settings")
-            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=
             with gr.Accordion("Advanced Settings", open=False):
                 temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                 top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                 top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                 repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
         with gr.Column(scale=4):
-            chatbot = gr.Chatbot(label="Chat", type="messages")
             with gr.Row():
-                user_input = gr.Textbox(placeholder="Type
                 submit_button = gr.Button("Send", variant="primary", scale=1)
                 clear_button = gr.Button("Clear", scale=1)
-            gr.Markdown("**Try these
             with gr.Row():
-
-
-
 
     submit_button.click(
         fn=generate_response,
         inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
         outputs=[chatbot, history_state]
-    ).then(
 
-
 
 demo.launch(ssr_mode=False)
-
-# If you still see missing CUDA kernels warnings, reinstall AutoGPTQ with CUDA support:
-# pip install git+https://github.com/PanQiWei/AutoGPTQ.git#egg=auto-gptq[cuda]
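The removed GPTQ path is only partially visible in the diff view: the auto_gptq import, the model directory, and the quantized weight basename are cut off. A minimal sketch of what that setup presumably looked like, with the cut-off pieces filled in as assumptions (the model path and basename below are placeholders, not the repository's original values):

import torch
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Placeholders: the diff does not show the actual path or basename that was used.
phi4_model_path = "path/to/phi4-gptq-2bit"
model_basename = "model"

quantize_config = BaseQuantizeConfig(
    bits=2,          # 2-bit quantization
    group_size=128,  # grouping size
    desc_act=False,  # disable activation reordering
)

model = AutoGPTQForCausalLM.from_quantized(
    phi4_model_path,
    model_basename=model_basename,
    quantize_config=quantize_config,
    device_map="auto",
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)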
 import gradio as gr
+import spaces
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import torch
 from threading import Thread
+import re
+
+phi4_model_path = "Intelligent-Internet/II-Medical-8B"
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
+phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
+phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
+
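Before wiring the model into the Gradio callbacks, a quick way to confirm that the II-Medical-8B checkpoint loads and generates is a one-off, non-streaming call. This is a minimal sketch, assuming enough GPU or CPU memory for the full-precision weights:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Intelligent-Internet/II-Medical-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto")

# Single blocking generation, no streaming, just to confirm the checkpoint works.
inputs = tokenizer("A 58-year-old man presents with crushing chest pain.", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))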
+@spaces.GPU(duration=60)
 def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
     if not user_message.strip():
         return history_state, history_state
+
+    model = phi4_model
+    tokenizer = phi4_tokenizer
+    start_tag = "<|im_start|>"
+    sep_tag = "<|im_sep|>"
+    end_tag = "<|im_end|>"
 
+    system_message = """You are a highly knowledgeable and thoughtful AI medical assistant. Your primary role is to assist with diagnostic reasoning by evaluating patient symptoms, medical history, and relevant clinical context.
+
+Structure your response into two main sections using the following format: <think> {Thought section} </think> {Solution section}.
+
+In the <think> section, use structured clinical reasoning to:
+- Identify possible differential diagnoses based on the given symptoms.
+- Consider risk factors, medical history, duration, and severity of symptoms.
+- Use step-by-step logic to rule in or rule out conditions.
+- Reflect on diagnostic uncertainty and suggest further assessments if needed.
+
+In the <solution> section, provide your most likely diagnosis or clinical assessment along with the rationale. Include brief suggestions for potential next steps like labs, imaging, or referrals if appropriate.
+
+IMPORTANT: When referencing lab values or pathophysiological mechanisms, use LaTeX formatting for clarity. Use $...$ for inline and $$...$$ for block-level expressions.
+
+Now, please analyze and respond to the following case:
+"""
+
     prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
+    for message in history_state:
+        if message["role"] == "user":
+            prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
+        elif message["role"] == "assistant" and message["content"]:
+            prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
     prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
 
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+
     generation_kwargs = {
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"],
         "max_new_tokens": int(max_tokens),
         "do_sample": True,
+        "temperature": float(temperature),
         "top_k": int(top_k),
+        "top_p": float(top_p),
+        "repetition_penalty": float(repetition_penalty),
+        "streamer": streamer,
     }
 
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
 
     assistant_response = ""
     new_history = history_state + [
         {"role": "user", "content": user_message},
         {"role": "assistant", "content": ""}
     ]
+
+    for new_token in streamer:
+        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
+        assistant_response += cleaned_token
+        new_history[-1]["content"] = assistant_response.strip()
         yield new_history, new_history
 
     yield new_history, new_history
 
+# Updated example cases for medical diagnostics
 example_messages = {
+    "Chest Pain": "A 58-year-old man presents with chest pain that started 20 minutes ago while climbing stairs. He describes it as a heavy pressure in the center of his chest, radiating to his left arm. He has a history of hypertension and smoking. What is the likely diagnosis?",
+    "Shortness of Breath": "A 34-year-old woman presents with 3 days of worsening shortness of breath, low-grade fever, and a dry cough. She denies chest pain or recent travel. Pulse oximetry is 91% on room air.",
+    "Abdominal Pain": "A 22-year-old female presents with lower right quadrant abdominal pain, nausea, and fever. The pain started around the umbilicus and migrated to the right lower quadrant over the past 12 hours.",
+    "Pediatric Fever": "A 2-year-old child has a fever of 39.5°C, irritability, and a rash on the trunk and arms. The child received all standard vaccinations and has no sick contacts. What should be considered in the differential diagnosis?"
 }
 
+# Custom CSS
+css = """
+.markdown-body .katex {
+    font-size: 1.2em;
+}
+.markdown-body .katex-display {
+    margin: 1em 0;
+    overflow-x: auto;
+    overflow-y: hidden;
+}
+"""
+
+with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+    gr.Markdown(
+        """
+        # Medical Diagnosis Assistant
+        This AI assistant uses structured reasoning to evaluate clinical cases and assist with diagnostic decision-making. Includes LaTeX support for medical calculations.
+        """
+    )
+
+    gr.HTML("""
+    <script>
+    if (typeof window.MathJax === 'undefined') {
+        const script = document.createElement('script');
+        script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
+        script.async = true;
+        document.head.appendChild(script);
+        window.MathJax = {
+            tex2jax: {
+                inlineMath: [['$', '$']],
+                displayMath: [['$$', '$$']],
+                processEscapes: true
+            },
+            showProcessingMessages: false,
+            messageStyle: 'none'
+        };
+    }
+
+    function rerender() {
+        if (window.MathJax && window.MathJax.Hub) {
+            window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
+        }
+    }
+    setInterval(rerender, 1000);
+    </script>
+    """)
 
     history_state = gr.State([])
+
     with gr.Row():
         with gr.Column(scale=1):
             gr.Markdown("### Settings")
+            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
             with gr.Accordion("Advanced Settings", open=False):
                 temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                 top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                 top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                 repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
+
         with gr.Column(scale=4):
+            chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", elem_id="chatbot", show_copy_button=True)
             with gr.Row():
+                user_input = gr.Textbox(label="Describe patient symptoms...", placeholder="Type a clinical case here...", scale=3)
                 submit_button = gr.Button("Send", variant="primary", scale=1)
                 clear_button = gr.Button("Clear", scale=1)
+            gr.Markdown("**Try these example cases:**")
             with gr.Row():
+                example1_button = gr.Button("Chest Pain")
+                example2_button = gr.Button("Shortness of Breath")
+                example3_button = gr.Button("Abdominal Pain")
+                example4_button = gr.Button("Pediatric Fever")
 
     submit_button.click(
         fn=generate_response,
         inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
         outputs=[chatbot, history_state]
+    ).then(
+        fn=lambda: gr.update(value=""),
+        inputs=None,
+        outputs=user_input
+    )
+
+    clear_button.click(
+        fn=lambda: ([], []),
+        inputs=None,
+        outputs=[chatbot, history_state]
+    )
 
+    example1_button.click(fn=lambda: gr.update(value=example_messages["Chest Pain"]), inputs=None, outputs=user_input)
+    example2_button.click(fn=lambda: gr.update(value=example_messages["Shortness of Breath"]), inputs=None, outputs=user_input)
+    example3_button.click(fn=lambda: gr.update(value=example_messages["Abdominal Pain"]), inputs=None, outputs=user_input)
+    example4_button.click(fn=lambda: gr.update(value=example_messages["Pediatric Fever"]), inputs=None, outputs=user_input)
 
 demo.launch(ssr_mode=False)
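A few notes on the updated code. The hand-built <|im_start|>/<|im_sep|>/<|im_end|> prompt in generate_response follows the Phi-4 chat layout; if the model's tokenizer ships its own chat template (an assumption worth checking via tokenizer.chat_template), the same prompt can be produced with apply_chat_template instead of manual string concatenation:

# Sketch: build the prompt from the tokenizer's chat template instead of manual tags.
messages = [{"role": "system", "content": system_message}]
messages += [m for m in history_state if m["content"]]
messages.append({"role": "user", "content": user_message})

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a prompt string rather than token ids
    add_generation_prompt=True,  # end with the assistant header so generation continues from there
)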
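The injected MathJax script is one way to render the $...$ expressions the system prompt asks for; recent Gradio releases can also render LaTeX directly in the chatbot through latex_delimiters, which would avoid the external script and the one-second re-typeset timer. A sketch of that alternative (not what this commit does):

# Sketch: native LaTeX rendering in the chatbot instead of injecting MathJax.
chatbot = gr.Chatbot(
    label="Chat",
    type="messages",
    render_markdown=True,
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},   # block-level math
        {"left": "$", "right": "$", "display": False},    # inline math
    ],
)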
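Finally, the four example buttons are created and wired one by one; the same behaviour can be written as a loop over example_messages, which keeps the buttons and the case texts in sync if more examples are added later. A sketch that would replace the four gr.Button(...) lines and the four .click(...) calls, still inside the gr.Blocks context:

# Sketch: generate and wire the example buttons from example_messages.
example_buttons = {}
with gr.Row():
    for title in example_messages:
        example_buttons[title] = gr.Button(title)

for title, button in example_buttons.items():
    # Bind `title` as a default argument so each handler keeps its own case text.
    button.click(
        fn=lambda t=title: gr.update(value=example_messages[t]),
        inputs=None,
        outputs=user_input,
    )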