"""Gradio demo that classifies an uploaded image as a cat or a dog.

The Keras model is downloaded from the Hugging Face Hub at import time and is
run on the CPU only.
"""

import logging
import os

# Hide all GPUs *before* TensorFlow is imported so it never initialises CUDA;
# this app is meant to run the model on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import tensorflow as tf  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from PIL import Image  # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_REPO = "Ahmedhassan54/Image-Classification-Model"
MODEL_FILE = "best_model.keras"
# (width, height) passed to PIL's Image.resize before prediction.
IMG_SIZE = (150, 150)
# Row order of the bar-plot DataFrame.
CLASS_NAMES = ["Cat", "Dog"]

# Loaded Keras model, or None when loading failed.
model = None


def load_model():
    """Download the model from the Hub and load it into the global ``model``.

    Any failure is logged with its traceback and ``model`` is left as
    ``None`` so the UI can still start and report that the model is missing.
    """
    global model
    try:
        logger.info("Downloading model...")
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            cache_dir=".",
            # Re-fetch on every start so a newly pushed model is picked up.
            force_download=True,
        )
        logger.info("Model path: %s", model_path)
        with tf.device("/CPU:0"):
            model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully!")
    except Exception:
        # Best-effort: keep the app alive, but record the full traceback.
        logger.exception("Model loading failed")
        model = None


load_model()


def _plot_frame(dog_confidence):
    """Return the two-row DataFrame shown by the bar plot."""
    return pd.DataFrame(
        {
            "Class": CLASS_NAMES,
            "Confidence": [1 - dog_confidence, dog_confidence],
        }
    )


def _preprocess(image):
    """Convert a PIL image or array into a (1, H, W, 3) float batch in [0, 1]."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    # Grayscale, palette and RGBA uploads would otherwise yield a missing
    # channel axis or 1/4 channels, which breaks model.predict.
    image = image.convert("RGB").resize(IMG_SIZE)
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    return np.expand_dims(img_array, axis=0)


def classify_image(image):
    """Classify ``image`` as a cat or a dog.

    Args:
        image: PIL image (from ``gr.Image(type="pil")``), a numpy array, or
            ``None`` when nothing has been uploaded.

    Returns:
        ``(label, plot_data)``. ``label`` is a ``{"Cat": p, "Dog": q}`` dict
        for ``gr.Label``, or a plain message string when no prediction could
        be made. ``plot_data`` is the DataFrame for ``gr.BarPlot``.
    """
    if image is None:
        return {"Cat": 0.5, "Dog": 0.5}, _plot_frame(0.5)
    if model is None:
        # Report the problem instead of fabricating a prediction.
        return "Error: model is not loaded", _plot_frame(0.5)

    try:
        batch = _preprocess(image)
        with tf.device("/CPU:0"):
            pred = model.predict(batch, verbose=0)
        # Assumes a single sigmoid output giving P(Dog) -- TODO confirm
        # against the training code.
        confidence = float(pred[0][0])
    except Exception as e:
        logger.exception("Classification failed")
        # gr.Label accepts a plain string; a dict value must be numeric, so
        # {"Error": "<text>"} cannot be displayed.
        return f"Error: {e}", _plot_frame(0.5)

    results = {
        "Cat": round(1 - confidence, 4),
        "Dog": round(confidence, 4),
    }
    return results, _plot_frame(confidence)


with gr.Blocks() as demo:
    gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil")
            classify_btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            label_out = gr.Label(num_top_classes=2)
            plot_out = gr.BarPlot(
                _plot_frame(0.5),
                x="Class",
                y="Confidence",
                y_lim=[0, 1],
            )

    classify_btn.click(
        classify_image,
        inputs=img_input,
        outputs=[label_out, plot_out],
    )

    gr.Examples(
        examples=[
            ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
            ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"],
        ],
        inputs=img_input,
        outputs=[label_out, plot_out],
        fn=classify_image,
        cache_examples=True,
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)