import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
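
# II-Medical-8B is a medical reasoning model on the Hub; device_map="auto"
# lets accelerate place the weights on the available GPU (falling back to
# CPU), and torch_dtype="auto" keeps the checkpoint's native dtype. The
# phi4_* names appear to be carried over from the Phi-4 demo this Space
# seems adapted from.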
phi4_model_path = "Intelligent-Internet/II-Medical-8B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
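
# @spaces.GPU requests a ZeroGPU slot for up to 60 seconds per call on
# Hugging Face Spaces. The function is a generator: it yields the growing
# chat history so Gradio can stream tokens into the UI.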
@spaces.GPU(duration=60)
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    if not user_message.strip():
        return history_state, history_state

    model = phi4_model
    tokenizer = phi4_tokenizer

    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"
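
    # System prompt: asks the model to expose its step-by-step reasoning in a
    # <think> section before giving a patient-facing summary.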
system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step. Structure your answer in two sections:
<think> In this section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations. Explain your thought process step-by-step. </think>
In the Solution section, summarize your working diagnosis and the main differentials, and recommend next steps (e.g., tests, referral, lifestyle changes). Always clarify that this is not a substitute for evaluation by a licensed medical professional.
Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$).
Now, analyze the following case:"""
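
    # Flatten the running history into one prompt string using the hand-rolled
    # tags above. (A model-agnostic alternative would be
    # tokenizer.apply_chat_template; this demo builds the string manually.)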
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for message in history_state:
        if message["role"] == "user":
            prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
        elif message["role"] == "assistant" and message["content"]:
            prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
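
    # Tokenize the full prompt, then attach a TextIteratorStreamer so decoded
    # text chunks can be consumed as generate() produces them
    # (skip_prompt=True drops the echoed input prompt from the stream).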
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }
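
    # Run generation on a worker thread so this generator can consume the
    # streamer and yield incremental updates without blocking.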
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""}
    ]
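
    # Stream tokens, scrubbing any chat-format tags that leak through, and
    # update the last assistant turn in place. (Constructing the streamer with
    # skip_special_tokens=True would be an alternative to manual scrubbing.)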
    for new_token in streamer:
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
        assistant_response += cleaned_token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history

    yield new_history, new_history
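
# Canned cases wired to the example buttons in the UI below.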
example_messages = {
    "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
    "Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
    "Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
    "BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese."
}
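
# Enlarge KaTeX output in chat messages and let wide display math scroll
# horizontally instead of overflowing.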
css = """
.markdown-body .katex {
font-size: 1.2em;
}
.markdown-body .katex-display {
margin: 1em 0;
overflow-x: auto;
overflow-y: hidden;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provide preliminary diagnostic reasoning using LaTeX-rendered medical formulas where needed.")
    gr.HTML("""
    <script>
    if (typeof window.MathJax === 'undefined') {
        // Define the configuration before the loader script runs.
        window.MathJax = {
            tex2jax: {
                inlineMath: [['$', '$']],
                displayMath: [['$$', '$$']],
                processEscapes: true
            },
            showProcessingMessages: false,
            messageStyle: 'none'
        };
        const script = document.createElement('script');
        script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
        script.async = true;
        document.head.appendChild(script);
    }
    // Re-typeset every second so LaTeX in freshly streamed messages renders.
    function rerender() {
        if (window.MathJax && window.MathJax.Hub) {
            window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
        }
    }
    setInterval(rerender, 1000);
    </script>
    """)
    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", show_copy_button=True)
            with gr.Row():
                user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1 = gr.Button("Headache case")
                example2 = gr.Button("Chest pain")
                example3 = gr.Button("Abdominal pain")
                example4 = gr.Button("BMI calculation")
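
    # Wire events: generate_response streams (chatbot, history_state) updates;
    # once generation finishes, the input box is cleared.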
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )
    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history_state])

    example1.click(lambda: gr.update(value=example_messages["Headache case"]), None, user_input)
    example2.click(lambda: gr.update(value=example_messages["Chest pain"]), None, user_input)
    example3.click(lambda: gr.update(value=example_messages["Abdominal pain"]), None, user_input)
    example4.click(lambda: gr.update(value=example_messages["BMI calculation"]), None, user_input)
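
# ssr_mode=False turns off Gradio's server-side rendering (Gradio 5+), a
# common workaround when injecting custom <script> tags as done above.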
demo.launch(ssr_mode=False)