# Hugging Face file-viewer header (page residue, not part of app.py):
#   Uploader: Ahmedhassan54
#   Commit: 3ba7ae0 (verified) - "Upload app.py"
#   Size: 5.65 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import pandas as pd
import logging
# Setup logging: configure the root logger at INFO level and create the
# module-level logger that the functions in this file write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration: Hugging Face Hub repository id and the name of the model
# file inside that repository, both passed to hf_hub_download().
MODEL_REPO = "Ahmedhassan54/Image-Classification"
MODEL_FILE = "best_model.h5"
# Download model from Hugging Face Hub
def load_model_from_hf():
    """Download the model file from the Hugging Face Hub and load it.

    ``MODEL_FILE`` is fetched from ``MODEL_REPO`` with ``cache_dir="."`` and
    the resulting path is passed to ``tf.keras.models.load_model``.

    Returns:
        The object returned by ``tf.keras.models.load_model``.

    Raises:
        gr.Error: If the download or the model loading fails. The original
            exception is attached as ``__cause__``.
    """
    try:
        logger.info("Attempting to load model...")
        # force_download=True skips any previously downloaded copy, so the
        # file is fetched again on every call.
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            cache_dir=".",
            force_download=True,  # Ensure fresh download
        )
        logger.info(f"Model downloaded to: {model_path}")

        logger.info("Loading model...")
        loaded_model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully!")
        return loaded_model
    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback,
        # matching the exc_info=True logging done in classify_image.
        logger.exception(f"Model loading failed: {str(e)}")
        raise gr.Error(
            f"⚠️ Model loading failed: {str(e)}. Check the logs for details."
        ) from e
# Load the model once at import time. If loading fails, the failure is
# logged and `model` stays None so the rest of the module can still load.
model = None
try:
    model = load_model_from_hf()
except Exception as load_error:
    logger.error(f"Proceeding without model due to: {str(load_error)}")
def _preprocess_image(image, target_size=(150, 150)):
    """Turn a PIL image into a scaled, batched array for ``model.predict``."""
    # Force three channels so grayscale, palette and RGBA inputs produce the
    # same (height, width, 3) layout as a plain RGB image.
    # NOTE(review): assumes the model takes 3-channel RGB input - confirm
    # against the model's input shape.
    rgb_image = image.convert("RGB").resize(target_size)
    image_array = np.array(rgb_image) / 255.0
    # Add batch dimension
    return np.expand_dims(image_array, axis=0)


def classify_image(image):
    """Classify an image as cat or dog using the module-level ``model``.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray`` (converted with
            ``Image.fromarray``).

    Returns:
        A tuple ``(label_output, plot_data)``: a dict mapping ``"Cat"`` and
        ``"Dog"`` to rounded scores, and a ``pandas.DataFrame`` with
        ``Class`` and ``Confidence`` columns holding the unrounded scores.

    Raises:
        gr.Error: If no image is given, the input type is unsupported, the
            model is not loaded, or prediction fails.
    """
    try:
        logger.info("\nClassification started...")
        # Debug: Check input type
        logger.info(f"Input type: {type(image)}")
        if image is None:
            raise ValueError("No image provided")
        # Fail fast: no point preprocessing when there is no model to run.
        if model is None:
            raise gr.Error("Model failed to load. Cannot make predictions.")

        # Convert image if needed
        if isinstance(image, np.ndarray):
            logger.info("Converting numpy array to PIL Image")
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            raise ValueError(f"Unexpected image type: {type(image)}")

        logger.info("Preprocessing image...")
        image_array = _preprocess_image(image)
        logger.info(f"Image array shape: {image_array.shape}")

        logger.info("Making prediction...")
        prediction = model.predict(image_array, verbose=0)
        logger.info(f"Raw prediction: {prediction}")
        # The first value of the first row is used as the "Dog" score and
        # its complement as the "Cat" score.
        # NOTE(review): presumably a single sigmoid output - confirm.
        confidence = float(prediction[0][0])
        logger.info(f"Confidence score: {confidence}")

        # Format outputs
        label_output = {
            "Cat": round(1 - confidence, 4),
            "Dog": round(confidence, 4),
        }
        # Create dataframe for bar plot
        plot_data = pd.DataFrame({
            'Class': ['Cat', 'Dog'],
            'Confidence': [1 - confidence, confidence],
        })
        logger.info("Classification successful!")
        logger.info(f"Results: {label_output}")
        return label_output, plot_data
    except gr.Error as e:
        # Already a user-facing error; re-raise as-is instead of wrapping it
        # in a second gr.Error.
        logger.error(f"Error during classification: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Error during classification: {str(e)}", exc_info=True)
        raise gr.Error(f"🔴 Classification failed: {str(e)}") from e
# Custom CSS passed to gr.Blocks(css=...): sets a gradient background on the
# Gradio container, hides the page footer, and defines a bold red
# `.error-message` class.
css = """
.gradio-container {
background: linear-gradient(to right, #f5f7fa, #c3cfe2);
}
footer {
visibility: hidden
}
.error-message {
color: red !important;
font-weight: bold !important;
}
"""
# Build the interface
def _placeholder_plot_data():
    """Return the neutral 50/50 frame shown before any prediction exists."""
    return pd.DataFrame({'Class': ['Cat', 'Dog'], 'Confidence': [0.5, 0.5]})


def _reset_outputs():
    """Return cleared values for the image, the label and the bar plot."""
    return [None, None, _placeholder_plot_data()]


with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    # The Markdown text is kept flush-left so the string itself is unchanged.
    gr.Markdown("""
# 🐾 Cat vs Dog Classifier 🦮
Upload an image to classify whether it's a cat or dog
""")
    with gr.Row():
        # Left column: image upload plus the action buttons.
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            with gr.Row():
                submit_btn = gr.Button("Classify", variant="primary")
                clear_btn = gr.Button("Clear")
        # Right column: predicted label and the confidence bar plot.
        with gr.Column():
            label_output = gr.Label(label="Predictions", num_top_classes=2)
            confidence_bar = gr.BarPlot(
                _placeholder_plot_data(),
                x="Class",
                y="Confidence",
                y_lim=[0, 1],
                title="Confidence Scores",
                width=400,
                height=300,
                container=False,
            )
    # Example images
    gr.Examples(
        examples=[
            ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
            ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"],
        ],
        inputs=image_input,
        outputs=[label_output, confidence_bar],
        fn=classify_image,
        # Caching calls classify_image on every example, and that raises
        # when the model failed to load; only cache when a model exists so
        # the app can still start without one.
        cache_examples=model is not None,
        label="Try these examples:",
    )
    # Button actions
    submit_btn.click(
        fn=classify_image,
        inputs=image_input,
        outputs=[label_output, confidence_bar],
        api_name="classify",
    )
    # Clear also resets the label so no stale prediction stays on screen.
    clear_btn.click(
        fn=_reset_outputs,
        inputs=None,
        outputs=[image_input, label_output, confidence_bar],
        show_progress=False,
    )
# Script entry point: launch the Gradio app only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    demo.launch(debug=True)