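"""Streaming Gradio chat demo for the Intelligent-Internet/II-Medical-8B model,
written to run as a Hugging Face Space."""
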
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread

model_path = "Intelligent-Internet/II-Medical-8B"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load weights once at import time; device_map="auto" places them on the
# available GPU(s), falling back to CPU.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

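# On a Hugging Face ZeroGPU Space, @spaces.GPU attaches a GPU to each call for
# up to the given duration; on regular hardware the decorator is a no-op.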
@spaces.GPU(duration=60)
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    if not user_message.strip():
        yield history, history
        return

    system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step. Structure your answer in two sections: 

<think> In this section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations. Explain your thought process step-by-step. </think> 

In the Solution section, summarize your working diagnosis, differential options, and suggest what to do next (e.g., tests, referral, lifestyle changes). Always clarify that this is not a replacement for a licensed medical professional.

Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$). 

Now, analyze the following case:"""

    # Build the conversation as role/content messages and let the tokenizer's
    # chat template insert the model-specific control tokens.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})

    # Add the current user message and the generation prompt for the assistant turn
    messages.append({"role": "user", "content": user_message})
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }

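    # model.generate blocks until completion, so run it on a background thread
    # and consume the streamer from this thread to surface tokens as they arrive.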
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Append the new user turn, then stream the growing assistant reply into it
    new_history = history.copy() + [[user_message, ""]]

    assistant_response = ""
    for new_token in streamer:
        assistant_response += new_token
        new_history[-1][1] = assistant_response.strip()
        # Yield on every token so the Chatbot updates incrementally
        yield new_history, new_history

    thread.join()


example_messages = {
    "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
    "Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
    "Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
    "BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese."
}

css = """
.markdown-body .katex { 
    font-size: 1.2em; 
}
.markdown-body .katex-display { 
    margin: 1em 0; 
    overflow-x: auto;
    overflow-y: hidden;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provide preliminary diagnostic reasoning using LaTeX-rendered medical formulas where needed.")

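    # Load MathJax 2.x and re-typeset once per second so LaTeX in streamed
    # replies renders. Note that some Gradio versions do not execute <script>
    # tags inside gr.HTML; gr.Blocks(head=...) is a more reliable place for this.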
    gr.HTML("""
    <script>
    if (typeof window.MathJax === 'undefined') {
        // MathJax 2.x reads window.MathJax as its configuration when it loads,
        // so define the config before injecting the script tag.
        window.MathJax = {
            tex2jax: {
                inlineMath: [['$', '$']],
                displayMath: [['$$', '$$']],
                processEscapes: true
            },
            showProcessingMessages: false,
            messageStyle: 'none'
        };
        const script = document.createElement('script');
        script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
        script.async = true;
        document.head.appendChild(script);
    }
    function rerender() {
        if (window.MathJax && window.MathJax.Hub) {
            window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
        }
    }
    setInterval(rerender, 1000);
    </script>
    """)

    chatbot = gr.Chatbot(label="Chat", render_markdown=True, show_copy_button=True)
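    # Server-side copy of the [[user, assistant], ...] pairs; passed back into
    # generate_response on each call and kept in sync with the chatbot.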
    history = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=64, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")

        with gr.Column(scale=4):
            with gr.Row():
                user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1 = gr.Button("Headache case")
                example2 = gr.Button("Chest pain")
                example3 = gr.Button("Abdominal pain")
                example4 = gr.Button("BMI calculation")

    # generate_response is a generator, so click() streams its yields to the UI
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, 
                repetition_penalty_slider, history],
        outputs=[chatbot, history]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )

    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])

    example1.click(lambda: gr.update(value=example_messages["Headache case"]), None, user_input)
    example2.click(lambda: gr.update(value=example_messages["Chest pain"]), None, user_input)
    example3.click(lambda: gr.update(value=example_messages["Abdominal pain"]), None, user_input)
    example4.click(lambda: gr.update(value=example_messages["BMI calculation"]), None, user_input)

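# ssr_mode=False turns off Gradio's server-side rendering, which can interfere
# with client-side scripts like the MathJax loader above.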
demo.launch(ssr_mode=False)