import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import pandas as pd
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
MODEL_REPO = "Ahmedhassan54/Image-Classification"
MODEL_FILE = "best_model.h5"
# (width, height) every input is resized to before prediction; must match the
# size the model was trained with.
IMG_SIZE = (150, 150)

# Initialize model (populated by load_model; stays None if loading fails)
model = None


def load_model():
    """Download the Keras model from the Hugging Face Hub into the global ``model``.

    Any failure is logged with its traceback and leaves ``model`` as ``None``.
    ``classify_image`` then falls back to a placeholder prediction instead of
    crashing the app.
    """
    global model
    try:
        logger.info("Downloading model...")
        # force_download=True re-fetches the file on every startup even when
        # a cached copy already exists under cache_dir.
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            cache_dir=".",
            force_download=True,
        )
        logger.info("Model path: %s", model_path)
        model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.exception("Model loading failed: %s", e)
        model = None


# Load model at startup
load_model()


def _confidence_frame(dog_confidence):
    """Build the two-row (Class, Confidence) table consumed by the bar plot."""
    return pd.DataFrame(
        {"Class": ["Cat", "Dog"], "Confidence": [1 - dog_confidence, dog_confidence]}
    )


def classify_image(image):
    """Classify *image* as cat or dog.

    Args:
        image: A PIL image or NumPy array from the ``gr.Image`` component, or
            ``None`` when the input is empty.

    Returns:
        A ``(label, plot_data)`` pair for the ``gr.Label`` and ``gr.BarPlot``
        outputs. On success ``label`` maps "Cat"/"Dog" to confidences rounded
        to four decimals. With no image it is a neutral 0.5/0.5. On error it
        is an ``"Error: ..."`` string, and ``plot_data`` is an empty table
        that still has the columns the bar plot expects.
    """
    if image is None:
        return {"Cat": 0.5, "Dog": 0.5}, _confidence_frame(0.5)

    try:
        # Convert to PIL Image if numpy array
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Force 3-channel RGB. Grayscale ("L") and palette ("P") images
        # would otherwise give a 2-D array with no batch axis, and RGBA a
        # 4-channel one; neither matches the model input.
        image = image.convert("RGB").resize(IMG_SIZE)
        img_array = np.asarray(image, dtype=np.float32) / 255.0
        img_array = np.expand_dims(img_array, axis=0)  # add batch axis

        # Predict
        if model is not None:
            pred = model.predict(img_array, verbose=0)
            # NOTE(review): assumes a single sigmoid output where values near
            # 1 mean "Dog" -- confirm against the training label order.
            confidence = float(pred[0][0])
        else:
            logger.warning("Model not loaded; returning placeholder prediction")
            confidence = 0.75  # Demo value

        results = {
            "Cat": round(1 - confidence, 4),
            "Dog": round(confidence, 4),
        }
        return results, _confidence_frame(confidence)

    except Exception as e:
        logger.exception("Error: %s", e)
        # gr.Label accepts a plain string label; a dict with a string value
        # is not valid confidence data.
        return f"Error: {e}", pd.DataFrame({"Class": [], "Confidence": []})


# Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil")
            classify_btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            label_out = gr.Label(num_top_classes=2)
            plot_out = gr.BarPlot(
                _confidence_frame(0.5),
                x="Class",
                y="Confidence",
                y_lim=[0, 1],
            )

    # Run classification when the button is pressed.
    classify_btn.click(
        fn=classify_image,
        inputs=img_input,
        outputs=[label_out, plot_out],
    )

    # Examples section (cache_examples=True runs classify_image on these
    # at startup and stores the outputs).
    gr.Examples(
        examples=[
            ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
            ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"],
        ],
        inputs=img_input,
        outputs=[label_out, plot_out],
        fn=classify_image,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()