matesoft committed on
Commit
0f878f4
·
verified ·
1 Parent(s): 05a7772

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -165
app.py CHANGED
@@ -1,166 +1,10 @@
1
- # -*- coding: utf-8 -*-
2
- """2_preprocessing_test.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/10c3x9G9z70J73l0LJDA8_VDZphQmHEZB
8
- """
9
-
10
-
11
  import pandas as pd
12
  import numpy as np
13
- import matplotlib.pyplot as plt
14
- from sklearn.preprocessing import LabelEncoder
15
- import os
16
- from sklearn.model_selection import train_test_split
17
  import pickle
18
- import warnings
19
- warnings.filterwarnings('ignore')
20
-
21
- df1 = pd.read_csv("/content/drive/MyDrive/Google Colab/disease-symptom-prediction/data/dataset.csv")
22
-
23
- print(df1.shape)
24
- df1.head()
25
-
26
- df1.sort_values(by='Disease', inplace=True)
27
- df1.head()
28
-
29
- df1.drop_duplicates(inplace=True)
30
- df1.shape
31
-
32
- df1['Disease'].value_counts()
33
-
34
- df1[df1['Disease']=="Fungal infection"]
35
-
36
- df1.fillna("none", inplace=True)
37
- df1[df1['Disease']=="Fungal infection"]
38
-
39
- df1.columns = df1.columns.str.strip().str.lower()
40
- for col in df1.columns:
41
- df1[col] = df1[col].astype(str).str.strip().str.lower()
42
-
43
-
44
- symptom_cols = [col for col in df1.columns if col.startswith('symptom')]
45
- print(symptom_cols)
46
-
47
- all_symptoms = set()
48
- for col in symptom_cols:
49
- for val in df1[col].unique():
50
- if val != 'none':
51
- all_symptoms.add(val)
52
- print(f"Unique symptoms: {len(all_symptoms)}")
53
-
54
- print(all_symptoms)
55
-
56
- df1.head()
57
-
58
- df1_num = pd.DataFrame(df1['disease'])
59
-
60
- for symptom in all_symptoms:
61
- df1_num[symptom] = df1[symptom_cols].apply(lambda row: int(symptom in row.values), axis=1)
62
-
63
- df1_num
64
-
65
- X = df1_num.drop('disease', axis=1)
66
- y = df1_num['disease']
67
- X.shape, y.shape
68
-
69
- X.sum(axis=1)
70
-
71
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
72
-
73
- print(np.unique(y_train, return_counts=True))
74
- print(np.unique(y_test, return_counts=True))
75
-
76
- from sklearn.ensemble import RandomForestClassifier
77
-
78
- model = RandomForestClassifier(n_estimators=100,random_state=42)
79
- model.fit(X_train, y_train)
80
- model.fit(X_train, y_train)
81
-
82
- import pickle
83
-
84
- # Save model
85
- with open("disease_model.pkl", "wb") as f:
86
- pickle.dump(model, f)
87
-
88
- # Save symptom list (to use in the app later)
89
- with open("symptoms.pkl", "wb") as f:
90
- pickle.dump(list(all_symptoms), f)
91
-
92
- # Original symptoms (keys)
93
- all_symptoms = sorted(all_symptoms)
94
-
95
- # Create display labels by replacing '_' with ' ' and capitalizing each word
96
- display_symptoms = [symptom.replace('_', ' ').title() for symptom in all_symptoms]
97
-
98
- # Create a mapping from display label back to original symptom key
99
- label_to_symptom = dict(zip(display_symptoms, all_symptoms))
100
-
101
- from sklearn.metrics import accuracy_score, f1_score
102
-
103
- y_train_pred = model.predict(X_train)
104
-
105
- train_accuracy = accuracy_score(y_train, y_train_pred)
106
- train_f1_score = f1_score(y_train, y_train_pred,average="weighted")
107
-
108
- print("Train Accuracy:", train_accuracy)
109
- print("Train f1 score:", train_f1_score)
110
-
111
- y_test_pred = model.predict(X_test)
112
- test_accuracy = accuracy_score(y_test, y_test_pred)
113
- test_f1_score = f1_score(y_test, y_test_pred, average="weighted")
114
- print("Train Accuracy:", test_accuracy)
115
- print("Train f1 score:", test_f1_score)
116
-
117
- import numpy as np
118
-
119
- # Example user symptoms
120
- user_symptoms = ['nausea', 'vomiting', 'abdominal_pain', 'diarrhoea']
121
-
122
- # Tip for the user
123
- if len(user_symptoms) < 4:
124
- print("Tip: The model performs better if you enter at least 4 symptoms.\n")
125
-
126
- # Convert symptoms to input vector
127
- input_vector = [1 if symptom in user_symptoms else 0 for symptom in all_symptoms]
128
- input_vector = np.array([input_vector])
129
-
130
- # Make prediction and get probabilities
131
- probas = model.predict_proba(input_vector)[0]
132
- max_proba = np.max(probas)
133
- predicted = model.classes_[np.argmax(probas)]
134
-
135
- # Confidence threshold
136
- threshold = 0.5
137
-
138
- # Print predicted disease and confidence
139
- if max_proba < threshold:
140
- print("Warning: The model is not confident about this prediction.")
141
- print(f"Predicted disease: {predicted} (Confidence: {max_proba * 100:.1f}%)")
142
- else:
143
- print(f"Predicted disease: {predicted} (Confidence: {max_proba * 100:.1f}%)")
144
-
145
- # Function to print top N diseases
146
- def print_top_diseases(probas, model, top_n=5):
147
- classes = model.classes_
148
- sorted_indices = np.argsort(probas)[::-1]
149
- print(f"\nTop {top_n} possible diseases:")
150
- for i in range(min(top_n, len(classes))):
151
- disease = classes[sorted_indices[i]]
152
- probability = probas[sorted_indices[i]]
153
- print(f"{i+1}. {disease}: {probability:.4f}")
154
-
155
- # Show top 5 possible diseases
156
- print_top_diseases(probas, model, top_n=5)
157
-
158
-
159
  import gradio as gr
160
- import pickle
161
- import numpy as np
162
 
163
  # --- 1. Load Disease Prediction Model ---
 
164
  with open("disease_model.pkl", "rb") as f:
165
  model = pickle.load(f)
166
 
@@ -174,12 +18,10 @@ label_to_symptom = dict(zip(display_symptoms, all_symptoms))
174
 
175
  # --- 2. Medical Knowledge Base ---
176
  MEDICAL_KNOWLEDGE = {
177
-
178
  "migraine": [
179
  "For migraines: (1) Rest in dark room (2) OTC pain relievers (ibuprofen/acetaminophen) (3) Apply cold compress (4) Consult neurologist if frequent",
180
  "Migraine treatment options include triptans (prescription) and caffeine. Avoid triggers like bright lights or strong smells."
181
  ],
182
-
183
  "allergy": [
184
  "Allergy management: (1) Antihistamines (cetirizine/loratadine) (2) Nasal sprays (3) Allergy shots (immunotherapy) for severe cases",
185
  "For food allergies: Strict avoidance, carry epinephrine auto-injector (EpiPen), read food labels carefully"
@@ -275,13 +117,31 @@ body, .gradio-container {
275
  color: var(--text) !important;
276
  font-family: 'Segoe UI', Roboto, sans-serif;
277
  }
278
- /* [Keep all your existing CSS styles] */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  """
280
 
281
- with gr.Blocks(css=custom_css) as demo:
282
  gr.Markdown("""
283
  <div style="text-align: center; margin-bottom: 20px;">
284
- <h1 style="margin-bottom: 5px;">🧬 Medical Diagnosis Assistant</h1>
285
  <p style="color: #4fc3f7; font-size: 16px;">Select symptoms for diagnosis and get medical advice</p>
286
  </div>
287
  """)
@@ -295,9 +155,9 @@ with gr.Blocks(css=custom_css) as demo:
295
  interactive=True
296
  )
297
  predict_btn = gr.Button("Analyze Symptoms", variant="primary")
298
- prediction_output = gr.Markdown(
299
  label="Diagnosis Results",
300
- value="Your results will appear here..."
301
  )
302
 
303
  with gr.Column(scale=1, min_width=400):
@@ -305,7 +165,10 @@ with gr.Blocks(css=custom_css) as demo:
305
  chatbot = gr.Chatbot(
306
  label="Chat with Medical Advisor",
307
  show_label=False,
308
- bubble_full_width=False
 
 
 
309
  )
310
  with gr.Row():
311
  user_input = gr.Textbox(
@@ -334,4 +197,5 @@ with gr.Blocks(css=custom_css) as demo:
334
  outputs=[chatbot, user_input]
335
  )
336
 
 
337
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import numpy as np
 
 
 
 
3
  import pickle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import gradio as gr
 
 
5
 
6
  # --- 1. Load Disease Prediction Model ---
7
+ # Load model and symptoms from files (upload these to your HF Space)
8
  with open("disease_model.pkl", "rb") as f:
9
  model = pickle.load(f)
10
 
 
18
 
19
  # --- 2. Medical Knowledge Base ---
20
  MEDICAL_KNOWLEDGE = {
 
21
  "migraine": [
22
  "For migraines: (1) Rest in dark room (2) OTC pain relievers (ibuprofen/acetaminophen) (3) Apply cold compress (4) Consult neurologist if frequent",
23
  "Migraine treatment options include triptans (prescription) and caffeine. Avoid triggers like bright lights or strong smells."
24
  ],
 
25
  "allergy": [
26
  "Allergy management: (1) Antihistamines (cetirizine/loratadine) (2) Nasal sprays (3) Allergy shots (immunotherapy) for severe cases",
27
  "For food allergies: Strict avoidance, carry epinephrine auto-injector (EpiPen), read food labels carefully"
 
117
  color: var(--text) !important;
118
  font-family: 'Segoe UI', Roboto, sans-serif;
119
  }
120
+ .gr-button {
121
+ background: var(--primary) !important;
122
+ color: var(--secondary) !important;
123
+ border: none !important;
124
+ }
125
+ .gr-button:hover {
126
+ opacity: 0.9 !important;
127
+ }
128
+ .gr-checkbox {
129
+ background: var(--card-bg) !important;
130
+ border-color: var(--primary) !important;
131
+ }
132
+ .gr-checkbox label {
133
+ color: var(--text) !important;
134
+ }
135
+ .gr-interface {
136
+ max-width: 1200px !important;
137
+ margin: 0 auto !important;
138
+ }
139
  """
140
 
141
+ with gr.Blocks(css=custom_css, title="Medical Diagnosis Assistant") as demo:
142
  gr.Markdown("""
143
  <div style="text-align: center; margin-bottom: 20px;">
144
+ <h1 style="margin-bottom: 5px; color: #4fc3f7;">🧬 Medical Diagnosis Assistant</h1>
145
  <p style="color: #4fc3f7; font-size: 16px;">Select symptoms for diagnosis and get medical advice</p>
146
  </div>
147
  """)
 
155
  interactive=True
156
  )
157
  predict_btn = gr.Button("Analyze Symptoms", variant="primary")
158
+ prediction_output = gr.HTML(
159
  label="Diagnosis Results",
160
+ value="<div style='padding: 20px; background: #001a33; border-radius: 8px; color: white;'>Your results will appear here...</div>"
161
  )
162
 
163
  with gr.Column(scale=1, min_width=400):
 
165
  chatbot = gr.Chatbot(
166
  label="Chat with Medical Advisor",
167
  show_label=False,
168
+ bubble_full_width=False,
169
+ avatar_images=(
170
+ None, (None, "assets/doctor_avatar.png")
171
+ )
172
  )
173
  with gr.Row():
174
  user_input = gr.Textbox(
 
197
  outputs=[chatbot, user_input]
198
  )
199
 
200
+ # For Hugging Face Spaces
201
  demo.launch()