# Hugging Face Space — "Running on Zero" (ZeroGPU) medical diagnosis demo.
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import re

# Checkpoint for the medical-reasoning chat model served by this Space.
phi4_model_path = "Intelligent-Internet/II-Medical-8B"

# Prefer the first CUDA device when available; otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# One-time model/tokenizer load at import time. device_map="auto" lets
# accelerate place the layers; torch_dtype="auto" keeps the checkpoint's
# native precision instead of forcing float32.
phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
@spaces.GPU  # NOTE(review): `spaces` was imported but never used, and the Space runs on ZeroGPU — this decorator was presumably intended; confirm duration needs.
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    """Stream an assistant reply to `user_message`.

    Generator: yields `(chat_history, chat_history)` pairs so both the
    Chatbot component and the session-state copy update as tokens arrive.

    Parameters mirror the UI sliders: `max_tokens`, `temperature`, `top_k`,
    `top_p`, `repetition_penalty`; `history_state` is a list of
    {"role": ..., "content": ...} message dicts.
    """
    # Ignore empty / whitespace-only submissions. This is a generator, so a
    # bare `return value` would yield nothing and Gradio would never receive
    # an update — yield the unchanged history first, then stop.
    if not user_message.strip():
        yield history_state, history_state
        return

    model = phi4_model
    tokenizer = phi4_tokenizer

    # Chat-markup tags used to assemble the raw prompt string.
    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"

    system_message = """You are a highly knowledgeable and thoughtful AI medical assistant. Your primary role is to assist with diagnostic reasoning by evaluating patient symptoms, medical history, and relevant clinical context.
Structure your response into two main sections using the following format: <think> {Thought section} </think> {Solution section}.
In the <think> section, use structured clinical reasoning to:
- Identify possible differential diagnoses based on the given symptoms.
- Consider risk factors, medical history, duration, and severity of symptoms.
- Use step-by-step logic to rule in or rule out conditions.
- Reflect on diagnostic uncertainty and suggest further assessments if needed.
In the <solution> section, provide your most likely diagnosis or clinical assessment along with the rationale. Include brief suggestions for potential next steps like labs, imaging, or referrals if appropriate.
IMPORTANT: When referencing lab values or pathophysiological mechanisms, use LaTeX formatting for clarity. Use $...$ for inline and $$...$$ for block-level expressions.
Now, please analyze and respond to the following case:
"""

    # Rebuild the whole conversation (system + prior turns + new user turn)
    # as one prompt string, ending with an open assistant turn.
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for message in history_state:
        if message["role"] == "user":
            prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
        elif message["role"] == "assistant" and message["content"]:
            prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    # NOTE(review): the model was loaded with device_map="auto"; moving inputs
    # to the module-level `device` assumes a single-device placement — confirm.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }

    # Run generation on a worker thread so this generator can drain the
    # streamer and push partial output to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    for new_token in streamer:
        # Strip chat-markup tags the streamer may emit verbatim.
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
        assistant_response += cleaned_token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history
    # Final yield so the completed message is committed even if the last
    # streamed chunk was empty.
    yield new_history, new_history
# Example clinical cases: button label -> prefilled case description for the
# input textbox.
example_messages = {
    "Chest Pain": "A 58-year-old man presents with chest pain that started 20 minutes ago while climbing stairs. He describes it as a heavy pressure in the center of his chest, radiating to his left arm. He has a history of hypertension and smoking. What is the likely diagnosis?",
    "Shortness of Breath": "A 34-year-old woman presents with 3 days of worsening shortness of breath, low-grade fever, and a dry cough. She denies chest pain or recent travel. Pulse oximetry is 91% on room air.",
    "Abdominal Pain": "A 22-year-old female presents with lower right quadrant abdominal pain, nausea, and fever. The pain started around the umbilicus and migrated to the right lower quadrant over the past 12 hours.",
    "Pediatric Fever": "A 2-year-old child has a fever of 39.5°C, irritability, and a rash on the trunk and arms. The child received all standard vaccinations and has no sick contacts. What should be considered in the differential diagnosis?",
}
# Custom CSS: enlarge rendered KaTeX math and keep wide display equations
# horizontally scrollable inside the chat markdown.
css = """
.markdown-body .katex {
    font-size: 1.2em;
}
.markdown-body .katex-display {
    margin: 1em 0;
    overflow-x: auto;
    overflow-y: hidden;
}
"""
# ---------------------------------------------------------------------------
# UI layout: settings column (sliders) + chat column, example-case buttons,
# and event wiring for submit / clear / examples.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        """
# Medical Diagnosis Assistant
This AI assistant uses structured reasoning to evaluate clinical cases and assist with diagnostic decision-making. Includes LaTeX support for medical calculations.
"""
    )
    # Inject MathJax (if not already present) and periodically re-typeset so
    # LaTeX in streamed chat messages gets rendered.
    gr.HTML("""
    <script>
    if (typeof window.MathJax === 'undefined') {
        const script = document.createElement('script');
        script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
        script.async = true;
        document.head.appendChild(script);
        window.MathJax = {
            tex2jax: {
                inlineMath: [['$', '$']],
                displayMath: [['$$', '$$']],
                processEscapes: true
            },
            showProcessingMessages: false,
            messageStyle: 'none'
        };
    }
    function rerender() {
        if (window.MathJax && window.MathJax.Hub) {
            window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
        }
    }
    setInterval(rerender, 1000);
    </script>
    """)

    # Per-session conversation history (list of role/content message dicts).
    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", elem_id="chatbot", show_copy_button=True)
            with gr.Row():
                user_input = gr.Textbox(label="Describe patient symptoms...", placeholder="Type a clinical case here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these example cases:**")
            with gr.Row():
                example1_button = gr.Button("Chest Pain")
                example2_button = gr.Button("Shortness of Breath")
                example3_button = gr.Button("Abdominal Pain")
                example4_button = gr.Button("Pediatric Fever")

    # Send: stream the model reply, then clear the input box.
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )

    # Clear: reset both the visible chat and the stored history.
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state]
    )

    # Example buttons prefill the input textbox with a canned case.
    example1_button.click(fn=lambda: gr.update(value=example_messages["Chest Pain"]), inputs=None, outputs=user_input)
    example2_button.click(fn=lambda: gr.update(value=example_messages["Shortness of Breath"]), inputs=None, outputs=user_input)
    example3_button.click(fn=lambda: gr.update(value=example_messages["Abdominal Pain"]), inputs=None, outputs=user_input)
    example4_button.click(fn=lambda: gr.update(value=example_messages["Pediatric Fever"]), inputs=None, outputs=user_input)

# ssr_mode=False avoids server-side-rendering issues on Spaces.
demo.launch(ssr_mode=False)