# VisoLearn's picture
# Update app.py
# afa4eb5 verified
# raw
# history blame
# 8.32 kB
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import re
# Hugging Face Hub id of the checkpoint served by this app. The ``phi4_``
# prefix on these globals is historical; other code in this file refers to
# them by these names.
phi4_model_path = "Intelligent-Internet/II-Medical-8B"

# Device that tokenized inputs are moved to before generation.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the weights and the tokenizer once at import time so that every
# request reuses the same objects.
phi4_model = AutoModelForCausalLM.from_pretrained(
    phi4_model_path,
    device_map="auto",
    torch_dtype="auto",
)
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
@spaces.GPU(duration=60)
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    """Stream an assistant reply to ``user_message`` for the Gradio chat UI.

    This is a generator event handler: every yielded pair is
    ``(chatbot_messages, history_state)`` where both items are the same list
    of ``{"role": ..., "content": ...}`` dicts, with the last entry holding
    the assistant text produced so far.

    Args:
        user_message: Text typed by the user. Blank input leaves the
            conversation unchanged.
        max_tokens: Upper bound passed to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling cutoff.
        repetition_penalty: Repetition penalty passed to ``generate``.
        history_state: Previous conversation as a list of role/content dicts.

    Yields:
        ``(new_history, new_history)`` after every streamed chunk of text and
        once more when generation has finished.
    """
    # A ``return value`` inside a generator is discarded by the caller, so the
    # unchanged history has to be yielded explicitly before stopping.
    if not user_message or not user_message.strip():
        yield history_state, history_state
        return

    model = phi4_model
    tokenizer = phi4_tokenizer

    system_message = """You are a highly knowledgeable and thoughtful AI medical assistant. Your primary role is to assist with diagnostic reasoning by evaluating patient symptoms, medical history, and relevant clinical context.
Structure your response into two main sections using the following format: <think> {Thought section} </think> {Solution section}.
In the <think> section, use structured clinical reasoning to:
- Identify possible differential diagnoses based on the given symptoms.
- Consider risk factors, medical history, duration, and severity of symptoms.
- Use step-by-step logic to rule in or rule out conditions.
- Reflect on diagnostic uncertainty and suggest further assessments if needed.
In the <solution> section, provide your most likely diagnosis or clinical assessment along with the rationale. Include brief suggestions for potential next steps like labs, imaging, or referrals if appropriate.
IMPORTANT: When referencing lab values or pathophysiological mechanisms, use LaTeX formatting for clarity. Use $...$ for inline and $$...$$ for block-level expressions.
Now, please analyze and respond to the following case:
"""

    # Assemble the conversation as role/content messages. Empty assistant
    # turns (e.g. an interrupted earlier reply) are dropped.
    messages = [{"role": "system", "content": system_message}]
    for message in history_state:
        role = message["role"]
        if role == "user" or (role == "assistant" and message["content"]):
            messages.append({"role": role, "content": message["content"]})
    messages.append({"role": "user", "content": user_message})

    # Let the tokenizer's bundled chat template render the prompt instead of
    # hand-writing role markers, so the format always matches the checkpoint
    # that is actually loaded. Tensors go to the device the model lives on
    # (it was loaded with device_map="auto").
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_special_tokens is forwarded to tokenizer.decode so end-of-turn
    # tokens never reach the chat window.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]

    for new_token in streamer:
        # Defensive: strip any role markers that come through as plain text.
        cleaned_token = (
            new_token.replace("<|im_start|>", "")
            .replace("<|im_sep|>", "")
            .replace("<|im_end|>", "")
        )
        assistant_response += cleaned_token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history

    # The streamer is exhausted only after generate() has finished; join so
    # the worker thread is reclaimed before the handler returns.
    thread.join()
    yield new_history, new_history
# Sample clinical vignettes, keyed by the label shown for each case.
example_messages = {
    "Chest Pain": (
        "A 58-year-old man presents with chest pain that started 20 minutes ago "
        "while climbing stairs. He describes it as a heavy pressure in the center "
        "of his chest, radiating to his left arm. He has a history of hypertension "
        "and smoking. What is the likely diagnosis?"
    ),
    "Shortness of Breath": (
        "A 34-year-old woman presents with 3 days of worsening shortness of "
        "breath, low-grade fever, and a dry cough. She denies chest pain or "
        "recent travel. Pulse oximetry is 91% on room air."
    ),
    "Abdominal Pain": (
        "A 22-year-old female presents with lower right quadrant abdominal pain, "
        "nausea, and fever. The pain started around the umbilicus and migrated to "
        "the right lower quadrant over the past 12 hours."
    ),
    "Pediatric Fever": (
        "A 2-year-old child has a fever of 39.5°C, irritability, and a rash on "
        "the trunk and arms. The child received all standard vaccinations and has "
        "no sick contacts. What should be considered in the differential "
        "diagnosis?"
    ),
}
# Stylesheet handed to gr.Blocks: enlarges rendered math and lets wide
# display equations scroll horizontally instead of overflowing.
css = "\n".join(
    [
        "",
        ".markdown-body .katex {",
        "font-size: 1.2em;",
        "}",
        ".markdown-body .katex-display {",
        "margin: 1em 0;",
        "overflow-x: auto;",
        "overflow-y: hidden;",
        "}",
        "",
    ]
)
# Gradio UI: generation settings on the left, chat window, input row and
# example-case shortcuts on the right.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        """
# Medical Diagnosis Assistant
This AI assistant uses structured reasoning to evaluate clinical cases and assist with diagnostic decision-making. Includes LaTeX support for medical calculations.
"""
    )

    # NOTE: math rendering is configured on the Chatbot below through
    # ``latex_delimiters``. A <script> tag placed inside gr.HTML is rendered
    # as static HTML and never executed, so loading MathJax that way has no
    # effect.

    # Conversation as a list of {"role", "content"} dicts, kept per session.
    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            # step=64 keeps both the default (4096) and the maximum (32768)
            # on the slider grid; with minimum=64 a step of 1024 reaches
            # neither value.
            max_tokens_slider = gr.Slider(64, 32768, step=64, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Chat",
                render_markdown=True,
                type="messages",
                elem_id="chatbot",
                show_copy_button=True,
                # "$$" is listed first so block math is matched before the
                # single-dollar inline delimiter.
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False},
                ],
            )
            with gr.Row():
                user_input = gr.Textbox(label="Describe patient symptoms...", placeholder="Type a clinical case here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these example cases:**")
            with gr.Row():
                example1_button = gr.Button("Chest Pain")
                example2_button = gr.Button("Shortness of Breath")
                example3_button = gr.Button("Abdominal Pain")
                example4_button = gr.Button("Pediatric Fever")

    # Stream the reply into the chat, then empty the input box.
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state],
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )

    # Reset both the visible chat and the stored history.
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state],
    )

    # Each example button only fills the textbox; the user still presses Send.
    example1_button.click(fn=lambda: gr.update(value=example_messages["Chest Pain"]), inputs=None, outputs=user_input)
    example2_button.click(fn=lambda: gr.update(value=example_messages["Shortness of Breath"]), inputs=None, outputs=user_input)
    example3_button.click(fn=lambda: gr.update(value=example_messages["Abdominal Pain"]), inputs=None, outputs=user_input)
    example4_button.click(fn=lambda: gr.update(value=example_messages["Pediatric Fever"]), inputs=None, outputs=user_input)

demo.launch(ssr_mode=False)