Create gradio_app.py
Browse files- gradio_app.py +179 -0
gradio_app.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pickle
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
# --- 1. Load Disease Prediction Model ---
|
6 |
+
with open("disease_model.pkl", "rb") as f:
|
7 |
+
model = pickle.load(f)
|
8 |
+
|
9 |
+
with open("symptoms.pkl", "rb") as f:
|
10 |
+
all_symptoms = pickle.load(f)
|
11 |
+
|
12 |
+
# Preprocess symptoms
|
13 |
+
all_symptoms = sorted(all_symptoms)
|
14 |
+
display_symptoms = [s.replace('_', ' ').title() for s in all_symptoms]
|
15 |
+
label_to_symptom = dict(zip(display_symptoms, all_symptoms))
|
16 |
+
|
17 |
+
# --- 2. Medical Knowledge Base ---
|
18 |
+
MEDICAL_KNOWLEDGE = {
|
19 |
+
|
20 |
+
"migraine": [
|
21 |
+
"For migraines: (1) Rest in dark room (2) OTC pain relievers (ibuprofen/acetaminophen) (3) Apply cold compress (4) Consult neurologist if frequent",
|
22 |
+
"Migraine treatment options include triptans (prescription) and caffeine. Avoid triggers like bright lights or strong smells."
|
23 |
+
],
|
24 |
+
|
25 |
+
"allergy": [
|
26 |
+
"Allergy management: (1) Antihistamines (cetirizine/loratadine) (2) Nasal sprays (3) Allergy shots (immunotherapy) for severe cases",
|
27 |
+
"For food allergies: Strict avoidance, carry epinephrine auto-injector (EpiPen), read food labels carefully"
|
28 |
+
],
|
29 |
+
"cold": [
|
30 |
+
"Treat colds with rest, fluids, and OTC pain relievers. See doctor if fever lasts >3 days",
|
31 |
+
"Most colds resolve in 7-10 days. Use decongestants for nasal congestion"
|
32 |
+
],
|
33 |
+
"headache": [
|
34 |
+
"For headaches: Hydrate, rest, and use OTC pain relievers sparingly",
|
35 |
+
"Persistent headaches require medical evaluation - consult your doctor"
|
36 |
+
],
|
37 |
+
"fever": [
|
38 |
+
"For fever: Rest, fluids, and acetaminophen/ibuprofen. Seek help if >39°C or lasts >3 days",
|
39 |
+
"High fever warning: Seek emergency care if fever >40°C or with stiff neck"
|
40 |
+
]
|
41 |
+
}
|
42 |
+
|
43 |
+
SPECIAL_RESPONSES = {
|
44 |
+
"general approaches": "I can provide specific guidance for: allergies, migraines, colds, fever, back pain, rashes. What condition are you asking about?",
|
45 |
+
"consult a doctor": "For these symptoms, seek medical care: severe pain, difficulty breathing, sudden weakness, high fever (>103°F), or symptoms lasting >7 days"
|
46 |
+
}
|
47 |
+
|
48 |
+
def get_medical_response(user_query):
|
49 |
+
user_query = user_query.lower()
|
50 |
+
|
51 |
+
# First check for special cases
|
52 |
+
for phrase, response in SPECIAL_RESPONSES.items():
|
53 |
+
if phrase in user_query:
|
54 |
+
return response
|
55 |
+
|
56 |
+
# Then check medical conditions
|
57 |
+
for condition, responses in MEDICAL_KNOWLEDGE.items():
|
58 |
+
if condition in user_query:
|
59 |
+
return np.random.choice(responses)
|
60 |
+
|
61 |
+
# Final improvement - suggest related conditions
|
62 |
+
related = [cond for cond in MEDICAL_KNOWLEDGE.keys() if cond in user_query]
|
63 |
+
if related:
|
64 |
+
return f"Are you asking about {', '.join(related)}? {np.random.choice(MEDICAL_KNOWLEDGE[related[0]])}"
|
65 |
+
|
66 |
+
return "I can advise on: " + ", ".join(MEDICAL_KNOWLEDGE.keys()) + ". Please be more specific."
|
67 |
+
|
68 |
+
# --- 3. Disease Prediction Function ---
|
69 |
+
def predict_disease(selected_labels):
|
70 |
+
if not selected_labels or len(selected_labels) < 4:
|
71 |
+
return "⚠️ Please select at least 4 symptoms for accurate results."
|
72 |
+
|
73 |
+
user_symptoms = [label_to_symptom[label] for label in selected_labels]
|
74 |
+
input_vector = [1 if symptom in user_symptoms else 0 for symptom in all_symptoms]
|
75 |
+
input_vector = np.array([input_vector])
|
76 |
+
probas = model.predict_proba(input_vector)[0]
|
77 |
+
max_proba = np.max(probas)
|
78 |
+
predicted = model.classes_[np.argmax(probas)]
|
79 |
+
|
80 |
+
sorted_indices = np.argsort(probas)[::-1]
|
81 |
+
top_diseases = [
|
82 |
+
f"<b>{i+1}. {model.classes_[idx]}</b> — {probas[idx]*100:.1f}%"
|
83 |
+
for i, idx in enumerate(sorted_indices[:3])
|
84 |
+
]
|
85 |
+
|
86 |
+
prediction_result = (
|
87 |
+
f"<div style='background: #001a33; padding: 15px; border-radius: 8px; margin-bottom: 15px;'>"
|
88 |
+
f"<h3 style='color: #4fc3f7; margin-top: 0;'>🩺 Predicted Disease</h3>"
|
89 |
+
f"<p style='font-size: 18px; color: white;'>{predicted} <span style='color: #4fc3f7'>({max_proba*100:.1f}% confidence)</span></p>"
|
90 |
+
"</div>"
|
91 |
+
"<div style='background: #001a33; padding: 15px; border-radius: 8px;'>"
|
92 |
+
"<h3 style='color: #4fc3f7; margin-top: 0;'>🔍 Top 3 Possible Diseases</h3>"
|
93 |
+
"<ul style='color: white; padding-left: 20px;'>" +
|
94 |
+
"".join([f"<li>{d}</li>" for d in top_diseases]) +
|
95 |
+
"</ul>"
|
96 |
+
"</div>"
|
97 |
+
)
|
98 |
+
return prediction_result
|
99 |
+
|
100 |
+
# --- 4. Chat Responder ---
|
101 |
+
def chatbot_respond(message, chat_history):
|
102 |
+
response = get_medical_response(message)
|
103 |
+
return chat_history + [(message, response)], ""
|
104 |
+
|
105 |
+
# --- 5. UI Setup ---
|
106 |
+
custom_css = """
|
107 |
+
:root {
|
108 |
+
--primary: #4fc3f7;
|
109 |
+
--secondary: #001a33;
|
110 |
+
--text: #ffffff;
|
111 |
+
--bg: #0a192f;
|
112 |
+
--card-bg: #0a2342;
|
113 |
+
--error: #ff6b6b;
|
114 |
+
}
|
115 |
+
body, .gradio-container {
|
116 |
+
background: var(--bg) !important;
|
117 |
+
color: var(--text) !important;
|
118 |
+
font-family: 'Segoe UI', Roboto, sans-serif;
|
119 |
+
}
|
120 |
+
/* [Keep all your existing CSS styles] */
|
121 |
+
"""
|
122 |
+
|
123 |
+
with gr.Blocks(css=custom_css) as demo:
|
124 |
+
gr.Markdown("""
|
125 |
+
<div style="text-align: center; margin-bottom: 20px;">
|
126 |
+
<h1 style="margin-bottom: 5px;">🧬 Medical Diagnosis Assistant</h1>
|
127 |
+
<p style="color: #4fc3f7; font-size: 16px;">Select symptoms for diagnosis and get medical advice</p>
|
128 |
+
</div>
|
129 |
+
""")
|
130 |
+
|
131 |
+
with gr.Row(equal_height=True):
|
132 |
+
with gr.Column(scale=1, min_width=300):
|
133 |
+
gr.Markdown("### 🔍 Symptom Checker")
|
134 |
+
symptoms_input = gr.CheckboxGroup(
|
135 |
+
choices=display_symptoms,
|
136 |
+
label="Select your symptoms:",
|
137 |
+
interactive=True
|
138 |
+
)
|
139 |
+
predict_btn = gr.Button("Analyze Symptoms", variant="primary")
|
140 |
+
prediction_output = gr.Markdown(
|
141 |
+
label="Diagnosis Results",
|
142 |
+
value="Your results will appear here..."
|
143 |
+
)
|
144 |
+
|
145 |
+
with gr.Column(scale=1, min_width=400):
|
146 |
+
gr.Markdown("### 💬 Medical Advisor")
|
147 |
+
chatbot = gr.Chatbot(
|
148 |
+
label="Chat with Medical Advisor",
|
149 |
+
show_label=False,
|
150 |
+
bubble_full_width=False
|
151 |
+
)
|
152 |
+
with gr.Row():
|
153 |
+
user_input = gr.Textbox(
|
154 |
+
placeholder="Ask about symptoms or treatments...",
|
155 |
+
label="",
|
156 |
+
show_label=False,
|
157 |
+
container=False,
|
158 |
+
scale=7
|
159 |
+
)
|
160 |
+
send_btn = gr.Button("Send", scale=1, min_width=80)
|
161 |
+
|
162 |
+
# Event handlers
|
163 |
+
predict_btn.click(
|
164 |
+
fn=predict_disease,
|
165 |
+
inputs=symptoms_input,
|
166 |
+
outputs=prediction_output
|
167 |
+
)
|
168 |
+
send_btn.click(
|
169 |
+
fn=chatbot_respond,
|
170 |
+
inputs=[user_input, chatbot],
|
171 |
+
outputs=[chatbot, user_input]
|
172 |
+
)
|
173 |
+
user_input.submit(
|
174 |
+
fn=chatbot_respond,
|
175 |
+
inputs=[user_input, chatbot],
|
176 |
+
outputs=[chatbot, user_input]
|
177 |
+
)
|
178 |
+
|
179 |
+
demo.launch()
|