# EmberDeepAI / app.py
# Last commit by AbdullahImran: "Update app.py" (6594ae4, verified)
# File size: 2.26 kB
import gradio as gr
import tensorflow as tf
from PIL import Image
import numpy as np
import tensorflow as tf
import keras.backend as K
# Focal-loss factory; used below to resolve the custom loss stored with the
# Xception model when it is loaded.
def focal_loss(gamma=2., alpha=0.25):
    """Build a focal-loss callable with fixed ``gamma`` and ``alpha``.

    Args:
        gamma: Focusing exponent applied to ``(1 - y_pred)``.
        alpha: Constant weight applied to every term.

    Returns:
        A ``focal_loss_fixed(y_true, y_pred)`` function returning the loss
        summed over the last axis. Its ``__name__`` must stay
        ``focal_loss_fixed`` so Keras can match it via ``custom_objects``.
    """

    def focal_loss_fixed(y_true, y_pred):
        # Clamp predictions strictly inside (0, 1) so log() stays finite.
        eps = K.epsilon()
        probs = tf.clip_by_value(y_pred, eps, 1. - eps)
        ce_term = -y_true * tf.math.log(probs)
        # y_true appears in both factors; for one-hot labels (presumably the
        # training format -- confirm) that is equivalent to a single factor.
        modulator = alpha * y_true * tf.math.pow((1 - probs), gamma)
        return tf.reduce_sum(modulator * ce_term, axis=-1)

    return focal_loss_fixed
# Load both models once at import time; predict_fire reuses these globals
# for every request.
# VGG16 is used as the fire / no-fire detector and needs no custom objects.
vgg16_model = tf.keras.models.load_model(filepath="vgg16_best_model.keras")
# The Xception severity model was saved with the custom focal loss, so Keras
# must be told how to rebuild the function registered as 'focal_loss_fixed'.
xception_model = tf.keras.models.load_model(
    filepath='xception_best.keras',
    custom_objects={'focal_loss_fixed': focal_loss()},
)
# Probability above which the VGG16 output counts as "fire".
FIRE_THRESHOLD = 0.5
# (width, height) input sizes expected by each model.
VGG16_INPUT_SIZE = (128, 128)
XCEPTION_INPUT_SIZE = (224, 224)
# Severity class names, indexed by the Xception model's argmax.
SEVERITY_LABELS = ("Mild", "Moderate", "Severe")


def _to_batch(img, size):
    """Resize an RGB PIL image and return it as a single-image batch in [0, 1].

    Args:
        img: An RGB ``PIL.Image.Image``.
        size: Target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)``.
    """
    resized = img.resize(size)
    # Scale 8-bit pixels to [0, 1]; presumably the scaling used in training.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)


def predict_fire(image):
    """Detect fire in an image and, if present, classify its severity.

    Args:
        image: Array from the Gradio ``Image(type="numpy")`` input (presumably
            HxWx3 uint8), or ``None`` when no image was supplied.

    Returns:
        ``(fire_status, severity)`` strings. The fire status is
        "Fire Detected" or "No Fire Detected". The severity is one of
        ``SEVERITY_LABELS``, or "N/A" when no fire is detected. An empty input
        yields ``("No image provided", "N/A")``.
    """
    if image is None:
        # Gradio's Image component yields None for an empty input;
        # Image.fromarray(None) would raise instead of giving a useful answer.
        return "No image provided", "N/A"

    img = Image.fromarray(image).convert("RGB")

    # Stage 1: fire detection. Assumes the VGG16 head is a single sigmoid
    # unit, as the [0][0] indexing implies -- confirm against training code.
    fire_pred = vgg16_model.predict(_to_batch(img, VGG16_INPUT_SIZE), verbose=0)
    is_fire = fire_pred[0][0] > FIRE_THRESHOLD
    if not is_fire:
        return "No Fire Detected", "N/A"

    # Stage 2: severity classification, only run when fire was detected.
    severity_pred = xception_model.predict(
        _to_batch(img, XCEPTION_INPUT_SIZE), verbose=0
    )
    severity = SEVERITY_LABELS[int(np.argmax(severity_pred[0]))]
    return "Fire Detected", severity
# Gradio UI: one uploaded image in, two text fields out
# (fire status and severity level, in predict_fire's return order).
interface = gr.Interface(
    predict_fire,
    gr.Image(type="numpy", label="Upload Image"),
    [
        gr.Textbox(label="Fire Status"),
        gr.Textbox(label="Severity Level"),
    ],
    description="Upload an image to predict fire and its severity level (Mild, Moderate, Severe).",
    title="Fire Prediction and Severity Classification",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()