# Page header captured together with this file ("Spaces: Sleeping"); it is not part of the program.
import os | |
import requests | |
import pandas as pd | |
import numpy as np | |
import joblib | |
import gradio as gr | |
from datetime import datetime, timedelta | |
from tensorflow.keras.models import load_model | |
from tensorflow.keras.preprocessing import image as keras_image | |
from tensorflow.keras.applications.vgg16 import preprocess_input as vgg_preprocess | |
from tensorflow.keras.applications.xception import preprocess_input as xce_preprocess | |
from tensorflow.keras.losses import BinaryFocalCrossentropy | |
from PIL import Image | |
# --- CONFIGURATION --- | |
# Named monitoring sites mapped to the (latitude, longitude) pair that is
# unpacked into the weather lookup.
FOREST_COORDS = {
    'Pakistan Forest': (34.0, 73.0),
}

# Daily variables requested from the archive endpoint, in request order.
_DAILY_VARIABLES = (
    "temperature_2m_max",
    "temperature_2m_min",
    "precipitation_sum",
    "windspeed_10m_max",
    "relative_humidity_2m_max",
    "relative_humidity_2m_min",
)

# URL template for the Open-Meteo archive API. Fill the {lat}, {lon}, {start}
# and {end} placeholders with str.format before issuing the request.
API_URL = "".join([
    "https://archive-api.open-meteo.com/v1/archive",
    "?latitude={lat}&longitude={lon}",
    "&start_date={start}&end_date={end}",
    "&daily=" + ",".join(_DAILY_VARIABLES),
    "&timezone=UTC",
])
# --- LOAD MODELS --- | |
def load_models():
    """Load every model the app uses from files given by relative path.

    Returns:
        The tuple ``(vgg_model, xce_model, rf_model, xgb_model, lr_model)``
        when every file loads. If anything raises along the way, the error
        is printed and a tuple of five ``None`` values is returned instead,
        so callers must be prepared for missing models.
    """

    def focal_loss_fixed(gamma=2., alpha=.25):
        """Build the loss callable registered under 'focal_loss_fixed'."""
        import tensorflow.keras.backend as K

        def loss_fn(y_true, y_pred):
            # Keep predictions strictly inside (0, 1) so the log is finite.
            epsilon = K.epsilon()
            clipped = K.clip(y_pred, epsilon, 1. - epsilon)
            cross_entropy = -y_true * K.log(clipped)
            # Down-weight confident predictions by (1 - p) ** gamma.
            modulator = alpha * K.pow(1 - clipped, gamma)
            return K.mean(modulator * cross_entropy, axis=-1)

        return loss_fn

    try:
        # Keras models need their custom losses supplied at load time.
        fire_detector = load_model(
            'vgg16_focal_unfreeze_more.keras',
            custom_objects={'BinaryFocalCrossentropy': BinaryFocalCrossentropy},
        )
        severity_classifier = load_model(
            'severity_post_tta.keras',
            custom_objects={'focal_loss_fixed': focal_loss_fixed()},
        )

        # Serialized estimators restored with joblib.
        forest = joblib.load('ensemble_rf_model.pkl')
        boosted = joblib.load('ensemble_xgb_model.pkl')
        logistic = joblib.load('wildfire_logistic_model_synthetic.joblib')
    except Exception as err:
        print(f"Error loading models: {err}")
        return (None,) * 5

    return fire_detector, severity_classifier, forest, boosted, logistic
# --- RULES & TEMPLATES --- | |
# Integer class index -> severity label.
target_map = dict(enumerate(('Mild', 'Moderate', 'Severe')))

# Integer trend code -> weather trend label.
trend_map = {
    1: 'Increasing',
    0: 'Stable',
    -1: 'Decreasing',
}

# Projected severity for each (current severity, weather trend) pair.
# A trend moves the severity one level down (-1), keeps it (0) or moves it
# one level up (+1); the result is clamped to the lowest/highest level, so
# 'Mild' cannot decrease further and 'Severe' cannot increase further.
_HIGHEST_LEVEL = max(target_map)
task_rules = {
    target_map[level]: {
        trend_map[step]: target_map[min(max(level + step, 0), _HIGHEST_LEVEL)]
        for step in sorted(trend_map)
    }
    for level in sorted(target_map)
}

# NOTE(review): the recommendations table is elided in this listing; the
# placeholder below is reproduced unchanged.
recommendations = { ... } # same as before
# --- PIPELINE FUNCTIONS --- | |
# Fire detector. The body is elided in this listing (`...`).
# pipeline() passes it an RGB PIL image and unpacks the result as
# `(fire, prob)`: `fire` is used as a truthy flag and `prob` is multiplied by
# 100 for display — presumably a probability in [0, 1]; TODO confirm.
def detect_fire(img): ...
# Severity classifier. The body is elided in this listing (`...`).
# pipeline() passes it the same RGB PIL image and forwards the result both to
# generate_recommendations() and to the 'Severity Level' textbox — presumably
# one of the target_map labels; TODO confirm.
def classify_severity(img): ...
# Weather trend lookup. The body is elided in this listing (`...`).
# pipeline() calls it with a (lat, lon) pair taken from FOREST_COORDS and
# shows the result in the 'Weather Trend' textbox — presumably it queries
# API_URL and returns one of the trend_map labels; TODO confirm.
def fetch_weather_trend(lat, lon): ...
# Recommendation builder. The body is elided in this listing (`...`).
# pipeline() calls it with the classified severity and the weather trend and
# renders the result in the Recommendations Markdown component — presumably
# it consults task_rules and recommendations; TODO confirm.
def generate_recommendations(original_severity, weather_trend): ...
# --- MAIN PIPELINE --- | |
def pipeline(image, progress=gr.Progress()):
    """Run the full wildfire analysis for one uploaded image.

    Args:
        image: Array handed to ``PIL.Image.fromarray``, or ``None`` when
            nothing was uploaded.
        progress: Gradio progress tracker; the ``gr.Progress()`` default is
            how Gradio recognises the parameter and supplies a live tracker
            when the function runs as an event handler.

    Returns:
        A 4-tuple ``(status, severity, trend, recommendations)`` matching the
        four output components wired to the button. ``severity`` and
        ``trend`` are ``"N/A"`` when no image was given or no fire was found.
    """
    progress(0.1, "Analyzing image…")
    if image is None:
        return ("No image provided", "N/A", "N/A", "**Please upload an image to analyze.**")
    img = Image.fromarray(image).convert('RGB')

    # Each stage is announced *before* it runs so that the label on screen
    # names the work in progress rather than the step that just finished.
    progress(0.3, "Detecting fire presence…")
    fire, prob = detect_fire(img)
    if not fire:
        return (f"✅ No wildfire detected (confidence {(1-prob)*100:.1f}% )", "N/A", "N/A", "**No wildfire detected. Continue monitoring.**")

    progress(0.6, "Classifying severity…")
    severity = classify_severity(img)
    # The weather lookup has no label of its own; it runs under the
    # severity-stage label.
    trend = fetch_weather_trend(*FOREST_COORDS['Pakistan Forest'])

    progress(0.8, "Computing recommendations…")
    recs = generate_recommendations(severity, trend)
    return (f"🔥 Wildfire detected! Confidence: {prob*100:.1f}%", severity, trend, recs)
# Load all models once at import time. load_models() catches its own errors,
# so after a failed load every one of these names is None.
vgg_model, xception_model, rf_model, xgb_model, lr_model = load_models()
# --- GRADIO BLOCKS UI --- | |
# Custom stylesheet passed to gr.Blocks. `.sidebar` and `.card` are attached
# through elem_classes below; `#title` and `#desc` match the ids used by the
# two header Markdown components.
css = '''
.sidebar { background: #111827; color: #F9FAFB; padding: 1rem; border-radius: 0.5rem; }
.card { background: #FFFFFF; border-radius: 1rem; box-shadow: 0 4px 12px rgba(0,0,0,0.1); padding: 1rem; margin-bottom: 1rem; }
#title { font-size: 2.25rem; font-weight: 700; color: #1F2937; }
#desc { font-size: 1rem; color: #4B5563; margin-bottom: 1rem; }
.gr-button { background: #EF4444 !important; color: white !important; border-radius: 0.75rem; padding: 0.75rem 1.25rem; }
'''

# NOTE(review): the nesting below is reconstructed (the listing lost its
# indentation): two columns side by side in one row, with the click handler
# registered inside the Blocks context — confirm against the deployed app.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Column widths keep the original 1 : 0.6 proportion but are written
        # as the integer ratio 5 : 3, because `scale` is documented as an
        # integer and Gradio warns about fractional values.
        with gr.Column(scale=5):
            # Input side: heading, description, image upload and run button.
            gr.Markdown("<div id='title'>Wildfire Command Center</div>", elem_id="title")
            gr.Markdown("<div id='desc'>Upload a forest image from Pakistan to detect wildfires, assess severity, forecast weather-driven trends, and receive expert management plans.</div>", elem_id="desc")
            image_input = gr.Image(type='numpy', label='Upload Forest Image', tool='editor')
            run_btn = gr.Button("🔍 Analyze Now", variant="primary")
        with gr.Column(scale=3, elem_classes="sidebar"):
            # Output side: read-only fields filled by the latest analysis.
            gr.Markdown("### Last Analysis", elem_classes="card")
            last_status = gr.Textbox(label='Fire Status', interactive=False)
            last_severity = gr.Textbox(label='Severity Level', interactive=False)
            last_trend = gr.Textbox(label='Weather Trend', interactive=False)
            last_recs = gr.Markdown(label='Recommendations', interactive=False)

    # pipeline() returns four values, mapped onto the outputs in this order.
    run_btn.click(
        fn=pipeline,
        inputs=image_input,
        outputs=[last_status, last_severity, last_trend, last_recs],
    )
if __name__ == '__main__':
    # Turn on request queuing first, then start the server without a public
    # share link.
    queued_app = demo.queue(api_open=True)
    queued_app.launch(share=False)