AbdullahImran commited on
Commit
b725fc0
·
verified ·
1 Parent(s): 7da1f68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -154
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import requests
3
  import pandas as pd
@@ -26,45 +28,39 @@ API_URL = (
26
 
27
  # --- LOAD MODELS ---
28
  def load_models():
29
- try:
30
- vgg_model = load_model(
31
- 'vgg16_focal_unfreeze_more.keras',
32
- custom_objects={'BinaryFocalCrossentropy': BinaryFocalCrossentropy}
33
- )
34
-
35
- def focal_loss_fixed(gamma=2., alpha=.25):
36
- import tensorflow.keras.backend as K
37
- def loss_fn(y_true, y_pred):
38
- eps = K.epsilon()
39
- y_pred = K.clip(y_pred, eps, 1. - eps)
40
- ce = -y_true * K.log(y_pred)
41
- w = alpha * K.pow(1 - y_pred, gamma)
42
- return K.mean(w * ce, axis=-1)
43
- return loss_fn
44
-
45
- xce_model = load_model(
46
- 'severity_post_tta.keras',
47
- custom_objects={'focal_loss_fixed': focal_loss_fixed()}
48
- )
49
- rf_model = joblib.load('ensemble_rf_model.pkl')
50
- xgb_model = joblib.load('ensemble_xgb_model.pkl')
51
- lr_model = joblib.load('wildfire_logistic_model_synthetic.joblib')
52
- return vgg_model, xce_model, rf_model, xgb_model, lr_model
53
- except Exception as e:
54
- raise gr.Error(f"Model loading failed: {str(e)}")
55
 
56
- try:
57
- vgg_model, xception_model, rf_model, xgb_model, lr_model = load_models()
58
- except Exception as e:
59
- print(f"Initial model loading failed: {str(e)}")
60
 
61
  # --- RULES & TEMPLATES ---
62
  target_map = {0: 'mild', 1: 'moderate', 2: 'severe'}
63
  trend_map = {1: 'increase', 0: 'same', -1: 'decrease'}
64
  task_rules = {
65
- 'mild': {'decrease': 'mild', 'same': 'mild', 'increase': 'moderate'},
66
- 'moderate': {'decrease': 'mild', 'same': 'moderate', 'increase': 'severe'},
67
- 'severe': {'decrease': 'moderate', 'same': 'severe', 'increase': 'severe'}
68
  }
69
  templates = {
70
  'mild': (
@@ -90,138 +86,83 @@ templates = {
90
  )
91
  }
92
 
93
- # --- FUNCTIONS ---
94
  def detect_fire(img):
95
- img_resized = img.resize((224, 224))
96
- arr = keras_image.img_to_array(img_resized)
97
- arr = np.expand_dims(arr, axis=0)
98
- arr = vgg_preprocess(arr)
99
- pred = vgg_model.predict(arr)[0][0]
100
- is_fire = pred >= 0.5
101
- return is_fire, pred
102
-
103
- def classify_severity(img):
104
- img_resized = img.resize((224, 224))
105
- arr = keras_image.img_to_array(img_resized)
106
- arr = np.expand_dims(arr, axis=0)
107
- arr = xce_preprocess(arr)
108
- feat = np.squeeze(arr)
109
- feat_flat = feat.flatten().reshape(1, -1)
110
-
111
- rf_pred = rf_model.predict_proba(feat_flat)
112
- xgb_pred = xgb_model.predict_proba(feat_flat)
113
- avg_pred = (rf_pred + xgb_pred) / 2
114
- final_class = np.argmax(avg_pred)
115
- return target_map[final_class]
116
 
117
- def fetch_weather_trend(lat, lon):
118
- today = datetime.utcnow().date()
119
- start_date = today - timedelta(days=2)
120
- end_date = today - timedelta(days=1)
121
-
122
- url = API_URL.format(lat=lat, lon=lon, start=start_date, end=end_date)
123
- response = requests.get(url)
124
- if response.status_code != 200:
125
- return 'same' # fallback if API fails
126
-
127
- data = response.json()
128
- temp_max = data['daily']['temperature_2m_max']
129
- wind_max = data['daily']['windspeed_10m_max']
130
- humidity_min = data['daily']['relative_humidity_2m_min']
131
 
132
- # crude trend logic: hotter, windier = worse
133
- temp_trend = np.sign(temp_max[-1] - temp_max[0])
134
- wind_trend = np.sign(wind_max[-1] - wind_max[0])
135
- humidity_trend = -np.sign(humidity_min[-1] - humidity_min[0])
 
 
 
 
136
 
137
- overall_trend = temp_trend + wind_trend + humidity_trend
138
- if overall_trend > 0:
139
- return 'increase'
140
- elif overall_trend < 0:
141
- return 'decrease'
142
- else:
143
- return 'same'
144
 
145
- def generate_recommendations(original, trend):
146
- projected = task_rules[original][trend]
147
- header = (
148
- f"## 🔥 Wildfire Situation Update\n"
149
- f"- **Original Severity:** {original.title()}\n"
150
- f"- **Weather Trend:** {trend.title()}\n"
151
- f"- **Projected Severity:** {projected.title()}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
152
  )
153
- paras = templates[projected].split("\n\n")
154
- formatted = "\n\n".join(paras)
155
- return header + formatted
156
-
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def pipeline(image):
158
  img = Image.fromarray(image).convert('RGB')
159
  fire, prob = detect_fire(img)
160
  if not fire:
161
- return (
162
- f"**No wildfire detected** (probability={prob:.2f})",
163
- "N/A",
164
- "N/A",
165
- "There is currently no sign of wildfire in the image. Continue normal monitoring."
166
- )
167
  sev = classify_severity(img)
168
  trend = fetch_weather_trend(*FOREST_COORDS['Pakistan Forest'])
169
  recs = generate_recommendations(sev, trend)
170
- return (
171
- f"**🔥 Fire Detected** (probability={prob:.2f})",
172
- sev.title(),
173
- trend.title(),
174
- recs
175
- )
176
-
177
- # --- GRADIO APP ---
178
- custom_css = """
179
- #component-0 { max-width: 800px; margin: 0 auto; }
180
- .gradio-container { background: #f0f4f7; }
181
- #upload-wildfire-image { min-height: 300px; }
182
- .panel {
183
- background: white !important;
184
- border-radius: 12px !important;
185
- padding: 20px !important;
186
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
187
- }
188
- .status-box {
189
- background: #fff3e6 !important;
190
- border: 1px solid #ffd8b3 !important;
191
- }
192
- .dark-red { color: #cc0000 !important; }
193
- .green { color: #008000 !important; }
194
- """
195
-
196
- with gr.Blocks(css=custom_css) as demo:
197
- gr.Markdown("# 🔥 Wildfire Detection & Management Assistant")
198
-
199
- with gr.Row(variant="panel"):
200
- with gr.Column(scale=2):
201
- inp = gr.Image(type="numpy", label="Satellite Image", elem_id="upload-wildfire-image")
202
- with gr.Column(scale=1):
203
- status = gr.Textbox(label="Fire Status", interactive=False, elem_classes="status-box")
204
- severity = gr.Textbox(label="Severity Level", interactive=False)
205
- trend = gr.Textbox(label="Weather Trend", interactive=False)
206
-
207
- with gr.Accordion("📋 Detailed Recommendations", open=False):
208
- rec_box = gr.Markdown()
209
-
210
- btn = gr.Button("Analyze", variant="primary")
211
- btn.click(
212
- fn=pipeline,
213
- inputs=inp,
214
- outputs=[status, severity, trend, rec_box],
215
- api_name="analyze"
216
- )
217
-
218
- gr.Markdown("---")
219
- gr.HTML("<div style='text-align: center; color: #666;'>© 2025 ForestAI Labs</div>")
220
-
221
- def handle_errors(inputs, outputs):
222
- for output in outputs:
223
- if isinstance(output, Exception):
224
- raise gr.Error("Analysis failed. Please check the input and try again.")
225
 
226
- if __name__ == "__main__":
227
- demo.launch() # Fixed the space typo in launch()
 
1
+ improve the text in the rcommendation, display each point on next line and improve the UI a little bit
2
+
3
  import os
4
  import requests
5
  import pandas as pd
 
28
 
29
  # --- LOAD MODELS ---
30
  def load_models():
31
+ # Fire detector (VGG16)
32
+ vgg_model = load_model(
33
+ 'vgg16_focal_unfreeze_more.keras',
34
+ custom_objects={'BinaryFocalCrossentropy': BinaryFocalCrossentropy}
35
+ )
36
+ # Severity classifier (Xception)
37
+ def focal_loss_fixed(gamma=2., alpha=.25):
38
+ import tensorflow.keras.backend as K
39
+ def loss_fn(y_true, y_pred):
40
+ eps = K.epsilon(); y_pred = K.clip(y_pred, eps, 1.-eps)
41
+ ce = -y_true * K.log(y_pred)
42
+ w = alpha * K.pow(1-y_pred, gamma)
43
+ return K.mean(w * ce, axis=-1)
44
+ return loss_fn
45
+ xce_model = load_model(
46
+ 'severity_post_tta.keras',
47
+ custom_objects={'focal_loss_fixed': focal_loss_fixed()}
48
+ )
49
+ # Ensemble and trend models
50
+ rf_model = joblib.load('ensemble_rf_model.pkl')
51
+ xgb_model = joblib.load('ensemble_xgb_model.pkl')
52
+ lr_model = joblib.load('wildfire_logistic_model_synthetic.joblib')
53
+ return vgg_model, xce_model, rf_model, xgb_model, lr_model
 
 
 
54
 
55
+ vgg_model, xception_model, rf_model, xgb_model, lr_model = load_models()
 
 
 
56
 
57
  # --- RULES & TEMPLATES ---
58
  target_map = {0: 'mild', 1: 'moderate', 2: 'severe'}
59
  trend_map = {1: 'increase', 0: 'same', -1: 'decrease'}
60
  task_rules = {
61
+ 'mild': {'decrease':'mild','same':'mild','increase':'moderate'},
62
+ 'moderate':{'decrease':'mild','same':'moderate','increase':'severe'},
63
+ 'severe': {'decrease':'moderate','same':'severe','increase':'severe'}
64
  }
65
  templates = {
66
  'mild': (
 
86
  )
87
  }
88
 
89
+ # --- PIPELINE FUNCTIONS ---
90
  def detect_fire(img):
91
+ x = keras_image.img_to_array(img.resize((128,128)))[None]
92
+ x = vgg_preprocess(x)
93
+ prob = float(vgg_model.predict(x)[0][0])
94
+ return prob >= 0.5, prob
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ def classify_severity(img):
98
+ x = keras_image.img_to_array(img.resize((224,224)))[None]
99
+ x = xce_preprocess(x)
100
+ preds = xception_model.predict(x)
101
+ rf_p = rf_model.predict(preds)[0]
102
+ xgb_p = xgb_model.predict(preds)[0]
103
+ ensemble = int(round((rf_p + xgb_p)/2))
104
+ return target_map.get(ensemble, 'moderate')
105
 
 
 
 
 
 
 
 
106
 
107
+ def fetch_weather_trend(lat, lon):
108
+ end = datetime.utcnow()
109
+ start = end - timedelta(days=1)
110
+ url = API_URL.format(lat=lat, lon=lon,
111
+ start=start.strftime('%Y-%m-%d'),
112
+ end=end.strftime('%Y-%m-%d'))
113
+ df = pd.DataFrame(requests.get(url).json().get('daily', {}))
114
+ for c in ['precipitation_sum','temperature_2m_max','temperature_2m_min',
115
+ 'relative_humidity_2m_max','relative_humidity_2m_min','windspeed_10m_max']:
116
+ df[c] = pd.to_numeric(df.get(c,[]), errors='coerce')
117
+ df['precipitation'] = df['precipitation_sum'].fillna(0)
118
+ df['temperature'] = (df['temperature_2m_max'] + df['temperature_2m_min'])/2
119
+ df['humidity'] = (df['relative_humidity_2m_max'] + df['relative_humidity_2m_min'])/2
120
+ df['wind_speed'] = df['windspeed_10m_max']
121
+ df['fire_risk_score'] = (
122
+ 0.4*(df['temperature']/55) +
123
+ 0.2*(1-df['humidity']/100) +
124
+ 0.3*(df['wind_speed']/60) +
125
+ 0.1*(1-df['precipitation']/50)
126
  )
127
+ feats = df[['temperature','humidity','wind_speed','precipitation','fire_risk_score']]
128
+ feat = feats.fillna(feats.mean()).iloc[-1].values.reshape(1,-1)
129
+ trend_cl = lr_model.predict(feat)[0]
130
+ return trend_map.get(trend_cl, 'same')
131
+
132
+
133
+ def generate_recommendations(original_severity, weather_trend):
134
+ # determine projected severity
135
+ proj = task_rules[original_severity][weather_trend]
136
+ rec = templates[proj]
137
+ # proper multi-line header
138
+ header = f"""**Original:** {original_severity.title()}
139
+ **Trend:** {weather_trend.title()}
140
+ **Projected:** {proj.title()}\n\n"""
141
+ return header + rec
142
+
143
+ # --- GRADIO INTERFACE ---
144
  def pipeline(image):
145
  img = Image.fromarray(image).convert('RGB')
146
  fire, prob = detect_fire(img)
147
  if not fire:
148
+ return f"No wildfire detected (prob={prob:.2f})", "N/A", "N/A", "**No wildfire detected. Stay alert.**"
 
 
 
 
 
149
  sev = classify_severity(img)
150
  trend = fetch_weather_trend(*FOREST_COORDS['Pakistan Forest'])
151
  recs = generate_recommendations(sev, trend)
152
+ return f"Fire Detected (prob={prob:.2f})", sev.title(), trend, recs
153
+
154
+ interface = gr.Interface(
155
+ fn=pipeline,
156
+ inputs=gr.Image(type='numpy', label='Upload Wildfire Image'),
157
+ outputs=[
158
+ gr.Textbox(label='Fire Status'),
159
+ gr.Textbox(label='Severity Level'),
160
+ gr.Textbox(label='Weather Trend'),
161
+ gr.Markdown(label='Recommendations')
162
+ ],
163
+ title='Wildfire Detection & Management Assistant',
164
+ description='Upload an image from a forest region in Pakistan to determine wildfire presence, severity, weather-driven trend, projection, and get expert recommendations.'
165
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ if __name__ == '__main__':
168
+ interface.launch()