import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import re

phi4_model_path = "Intelligent-Internet/II-Medical-8B"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
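# Note: device_map="auto" lets accelerate decide where the weights live; the
# tokenized inputs are moved to `device` below, which typically matches where
# the first transformer layers are placed.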

@spaces.GPU(duration=60)
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    if not user_message.strip():
        # In a generator, results must be yielded; a bare `return value` would be
        # discarded by Gradio, so yield the unchanged history and stop.
        yield history_state, history_state
        return
        
    model = phi4_model
    tokenizer = phi4_tokenizer
    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"

    system_message = """You are a highly knowledgeable and thoughtful AI medical assistant. Your primary role is to assist with diagnostic reasoning by evaluating patient symptoms, medical history, and relevant clinical context. 

Structure your response into two main sections using the following format: <think> {Thought section} </think> {Solution section}.

In the <think> section, use structured clinical reasoning to:
- Identify possible differential diagnoses based on the given symptoms.
- Consider risk factors, medical history, duration, and severity of symptoms.
- Use step-by-step logic to rule in or rule out conditions.
- Reflect on diagnostic uncertainty and suggest further assessments if needed.

In the Solution section (the text after </think>), provide your most likely diagnosis or clinical assessment along with the rationale. Include brief suggestions for potential next steps, such as labs, imaging, or referrals, if appropriate.

IMPORTANT: When referencing lab values or pathophysiological mechanisms, use LaTeX formatting for clarity. Use $...$ for inline and $$...$$ for block-level expressions.

Now, please analyze and respond to the following case:
"""
    
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for message in history_state:
        if message["role"] == "user":
            prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
        elif message["role"] == "assistant" and message["content"]:
            prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }
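    # Sampling is enabled (do_sample=True), so the temperature, top-k, top-p and
    # repetition-penalty values from the UI sliders all take effect.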

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""}
    ]
    
    for new_token in streamer:
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
        assistant_response += cleaned_token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history

    yield new_history, new_history
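
# Optional helper (a sketch, not wired into the UI above): shows how the
# <think>...</think> reasoning block requested in the system prompt could be
# separated from the final answer for post-processing. The tag names follow the
# prompt text; adjust them if the model emits different markers.
def split_think_solution(text):
    match = re.search(r"<think>(.*?)</think>(.*)", text, flags=re.DOTALL)
    if match:
        return match.group(1).strip(), match.group(2).strip()
    return "", text.strip()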

# Example clinical cases for the quick-start buttons
example_messages = {
    "Chest Pain": "A 58-year-old man presents with chest pain that started 20 minutes ago while climbing stairs. He describes it as a heavy pressure in the center of his chest, radiating to his left arm. He has a history of hypertension and smoking. What is the likely diagnosis?",
    "Shortness of Breath": "A 34-year-old woman presents with 3 days of worsening shortness of breath, low-grade fever, and a dry cough. She denies chest pain or recent travel. Pulse oximetry is 91% on room air.",
    "Abdominal Pain": "A 22-year-old female presents with lower right quadrant abdominal pain, nausea, and fever. The pain started around the umbilicus and migrated to the right lower quadrant over the past 12 hours.",
    "Pediatric Fever": "A 2-year-old child has a fever of 39.5°C, irritability, and a rash on the trunk and arms. The child received all standard vaccinations and has no sick contacts. What should be considered in the differential diagnosis?"
}

# Custom CSS
css = """
.markdown-body .katex { 
    font-size: 1.2em; 
}
.markdown-body .katex-display { 
    margin: 1em 0; 
    overflow-x: auto;
    overflow-y: hidden;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        """
        # Medical Diagnosis Assistant
        This AI assistant uses structured reasoning to evaluate clinical cases and assist with diagnostic decision-making. LaTeX rendering is supported for lab values and medical calculations.
        """
    )

    gr.HTML("""
    <script>
    if (typeof window.MathJax === 'undefined') {
        // Define the configuration object before MathJax.js loads so it is picked up.
        window.MathJax = {
            tex2jax: {
                inlineMath: [['$', '$']],
                displayMath: [['$$', '$$']],
                processEscapes: true
            },
            showProcessingMessages: false,
            messageStyle: 'none'
        };
        const script = document.createElement('script');
        script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
        script.async = true;
        document.head.appendChild(script);
    }

    function rerender() {
        if (window.MathJax && window.MathJax.Hub) {
            window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
        }
    }
    setInterval(rerender, 1000);
    </script>
    """)

    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", elem_id="chatbot", show_copy_button=True)
            with gr.Row():
                user_input = gr.Textbox(label="Describe patient symptoms...", placeholder="Type a clinical case here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these example cases:**")
            with gr.Row():
                example1_button = gr.Button("Chest Pain")
                example2_button = gr.Button("Shortness of Breath")
                example3_button = gr.Button("Abdominal Pain")
                example4_button = gr.Button("Pediatric Fever")

    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )

    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state]
    )

    example1_button.click(fn=lambda: gr.update(value=example_messages["Chest Pain"]), inputs=None, outputs=user_input)
    example2_button.click(fn=lambda: gr.update(value=example_messages["Shortness of Breath"]), inputs=None, outputs=user_input)
    example3_button.click(fn=lambda: gr.update(value=example_messages["Abdominal Pain"]), inputs=None, outputs=user_input)
    example4_button.click(fn=lambda: gr.update(value=example_messages["Pediatric Fever"]), inputs=None, outputs=user_input)

demo.launch(ssr_mode=False)