# Hugging Face Space: II-Medical-8B diagnostic-reasoning chat demo (runs on ZeroGPU).
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import re
# Hugging Face repo ID of the medical-reasoning model served by this Space.
phi4_model_path = "Intelligent-Internet/II-Medical-8B"
# Preferred device string for tokenized inputs.
# NOTE(review): the model below is loaded with device_map="auto", which may
# place weights on a different device than this value — confirm they agree.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Model and tokenizer are loaded once at import time (Space startup).
phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
@spaces.GPU(duration=60)
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    """Stream a chat completion from the medical model.

    Generator used as a Gradio event handler: repeatedly yields
    ``(chatbot_messages, history_state)`` pairs so the UI updates token by
    token. ``history_state`` is a list of ``{"role", "content"}`` dicts in
    Gradio "messages" format.

    Args:
        user_message: The new user turn (a string).
        max_tokens / temperature / top_k / top_p / repetition_penalty:
            Sampling parameters taken from the UI sliders.
        history_state: Prior conversation turns; not mutated in place.
    """
    if not user_message.strip():
        # BUG FIX: this function is a generator, so a plain
        # `return history_state, history_state` is swallowed (it becomes the
        # StopIteration value) and Gradio never receives an update. Yield the
        # unchanged history instead, then stop.
        yield history_state, history_state
        return

    model = phi4_model
    tokenizer = phi4_tokenizer

    # Chat-template control tags used to build the raw prompt string.
    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"

    system_message = """You are a highly knowledgeable and thoughtful AI medical assistant. Your primary role is to assist with diagnostic reasoning by evaluating patient symptoms, medical history, and relevant clinical context.
Structure your response into two main sections using the following format: <think> {Thought section} </think> {Solution section}.
In the <think> section, use structured clinical reasoning to:
- Identify possible differential diagnoses based on the given symptoms.
- Consider risk factors, medical history, duration, and severity of symptoms.
- Use step-by-step logic to rule in or rule out conditions.
- Reflect on diagnostic uncertainty and suggest further assessments if needed.
In the <solution> section, provide your most likely diagnosis or clinical assessment along with the rationale. Include brief suggestions for potential next steps like labs, imaging, or referrals if appropriate.
IMPORTANT: When referencing lab values or pathophysiological mechanisms, use LaTeX formatting for clarity. Use $...$ for inline and $$...$$ for block-level expressions.
Now, please analyze and respond to the following case:
"""

    # Rebuild the full conversation prompt from the system message plus every
    # prior turn, then open an assistant turn for the model to complete.
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for message in history_state:
        if message["role"] == "user":
            prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
        elif message["role"] == "assistant" and message["content"]:
            prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    # BUG FIX: with device_map="auto" the model may not live on the
    # module-level `device` string; move inputs to the device the model
    # actually reports so generation does not fail on a device mismatch.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stream tokens as they are generated; skip_prompt avoids re-emitting the
    # input prompt through the streamer.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }

    # model.generate blocks, so run it on a worker thread and consume the
    # streamer on this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_response = ""
    # Copy (not mutate) the incoming history and append the new turn pair.
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    for new_token in streamer:
        # Strip chat-template control tags that may leak into the stream.
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
        assistant_response += cleaned_token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history

    # BUG FIX: ensure the generation thread has fully finished before the
    # handler returns, so no background work outlives the GPU allocation.
    thread.join()
    yield new_history, new_history
# Preset clinical vignettes for the demo; each example button below copies
# the matching case text into the input box (keys mirror the button labels).
example_messages = {
    "Chest Pain": "A 58-year-old man presents with chest pain that started 20 minutes ago while climbing stairs. He describes it as a heavy pressure in the center of his chest, radiating to his left arm. He has a history of hypertension and smoking. What is the likely diagnosis?",
    "Shortness of Breath": "A 34-year-old woman presents with 3 days of worsening shortness of breath, low-grade fever, and a dry cough. She denies chest pain or recent travel. Pulse oximetry is 91% on room air.",
    "Abdominal Pain": "A 22-year-old female presents with lower right quadrant abdominal pain, nausea, and fever. The pain started around the umbilicus and migrated to the right lower quadrant over the past 12 hours.",
    "Pediatric Fever": "A 2-year-old child has a fever of 39.5°C, irritability, and a rash on the trunk and arms. The child received all standard vaccinations and has no sick contacts. What should be considered in the differential diagnosis?"
}
# Custom CSS: enlarge rendered KaTeX math in chat messages and let wide
# block-level equations scroll horizontally instead of overflowing.
css = """
.markdown-body .katex {
font-size: 1.2em;
}
.markdown-body .katex-display {
margin: 1em 0;
overflow-x: auto;
overflow-y: hidden;
}
"""
# UI layout and event wiring. The nesting of `with` contexts defines the
# visual layout, so statement order here is behavior — do not reorder.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # Page header.
    gr.Markdown(
        """
# Medical Diagnosis Assistant
This AI assistant uses structured reasoning to evaluate clinical cases and assist with diagnostic decision-making. Includes LaTeX support for medical calculations.
"""
    )
    # Inject MathJax so $...$ / $$...$$ in streamed messages render as math.
    # The setInterval re-typesets periodically because messages stream in
    # after initial page load.
    gr.HTML("""
<script>
if (typeof window.MathJax === 'undefined') {
const script = document.createElement('script');
script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
script.async = true;
document.head.appendChild(script);
window.MathJax = {
tex2jax: {
inlineMath: [['$', '$']],
displayMath: [['$$', '$$']],
processEscapes: true
},
showProcessingMessages: false,
messageStyle: 'none'
};
}
function rerender() {
if (window.MathJax && window.MathJax.Hub) {
window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
}
}
setInterval(rerender, 1000);
</script>
""")
    # Conversation history in Gradio "messages" format (list of role/content dicts).
    history_state = gr.State([])
    with gr.Row():
        # Left column: sampling controls.
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        # Right column: chat area, input row, and example-case buttons.
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", elem_id="chatbot", show_copy_button=True)
            with gr.Row():
                user_input = gr.Textbox(label="Describe patient symptoms...", placeholder="Type a clinical case here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these example cases:**")
            with gr.Row():
                example1_button = gr.Button("Chest Pain")
                example2_button = gr.Button("Shortness of Breath")
                example3_button = gr.Button("Abdominal Pain")
                example4_button = gr.Button("Pediatric Fever")
    # Send: stream the model response, then clear the input box.
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )
    # Clear: reset both the visible chat and the stored history.
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state]
    )
    # Example buttons copy a preset case into the input box (user still sends it).
    example1_button.click(fn=lambda: gr.update(value=example_messages["Chest Pain"]), inputs=None, outputs=user_input)
    example2_button.click(fn=lambda: gr.update(value=example_messages["Shortness of Breath"]), inputs=None, outputs=user_input)
    example3_button.click(fn=lambda: gr.update(value=example_messages["Abdominal Pain"]), inputs=None, outputs=user_input)
    example4_button.click(fn=lambda: gr.update(value=example_messages["Pediatric Fever"]), inputs=None, outputs=user_input)
# ssr_mode=False avoids server-side-rendering issues on Spaces.
demo.launch(ssr_mode=False)