File size: 3,405 Bytes
c96e0d1
 
 
 
 
 
 
3ba7ae0
 
 
 
 
c96e0d1
 
78cdd85
c96e0d1
 
36586fd
a0d6908
 
 
 
c96e0d1
36586fd
3ba7ae0
 
 
 
a0d6908
3ba7ae0
36586fd
c96e0d1
3ba7ae0
36586fd
c96e0d1
36586fd
a0d6908
c96e0d1
36586fd
a0d6908
c96e0d1
 
 
3ba7ae0
36586fd
a0d6908
 
c96e0d1
 
a0d6908
36586fd
c96e0d1
a0d6908
 
 
 
36586fd
 
 
 
 
 
a0d6908
 
3ba7ae0
 
c96e0d1
 
 
 
 
 
 
a0d6908
c96e0d1
 
36586fd
 
c96e0d1
36586fd
 
 
c96e0d1
 
 
36586fd
 
c96e0d1
 
36586fd
 
c96e0d1
36586fd
c96e0d1
 
36586fd
 
 
 
 
 
 
 
c96e0d1
 
 
 
 
36586fd
 
c96e0d1
a0d6908
c96e0d1
 
 
36586fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import pandas as pd
import logging

# Setup logging
# Configure the root logger at INFO so the progress messages logged by this
# module (model download, load result, classification errors) are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
# Hugging Face Hub repository id and the file inside it that load_model()
# passes to hf_hub_download().
MODEL_REPO = "Ahmedhassan54/Image-Classification"
MODEL_FILE = "best_model.h5"

# Initialize model
# Module-level handle to the loaded Keras model. load_model() assigns it and
# resets it to None on failure; classify_image() checks for None and falls
# back to a fixed demo confidence instead of a real prediction.
model = None

def load_model():
    """Download the classifier file from the Hugging Face Hub and load it.

    The file ``MODEL_FILE`` is fetched from ``MODEL_REPO`` into a cache rooted
    at the current directory and loaded with ``tf.keras.models.load_model``.

    The result is published through the module-level ``model``: it holds the
    loaded model on success and ``None`` on any failure. Failures are logged
    (with traceback) rather than raised, so the caller never sees an
    exception derived from ``Exception`` from this function.

    Returns:
        None.
    """
    global model
    try:
        logger.info("Downloading model...")
        # ``force_download`` is intentionally not set: hf_hub_download already
        # checks the cached file against the Hub and re-fetches only when it
        # is missing or outdated. Forcing it re-downloaded the whole model on
        # every process start and made startup fail when the Hub was
        # unreachable even though a usable cached copy existed.
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            cache_dir=".",
        )
        logger.info("Model path: %s", model_path)

        model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully!")
    except Exception as e:
        # Deliberate best-effort boundary: keep the app alive without a model,
        # but record the full traceback so the cause can be diagnosed.
        logger.exception("Model loading failed: %s", e)
        model = None

# Load model at startup
# Executed at import time, before the UI below is built. load_model() catches
# and logs its own errors, leaving `model` as None when loading fails.
load_model()

def _build_outputs(confidence):
    """Turn a dog-confidence value into the (label scores, plot data) pair."""
    cat_score = 1 - confidence
    scores = {
        "Cat": round(cat_score, 4),
        "Dog": round(confidence, 4)
    }
    plot_data = pd.DataFrame({
        'Class': ['Cat', 'Dog'],
        'Confidence': [cat_score, confidence]
    })
    return scores, plot_data


def classify_image(image):
    """Classify an image as cat or dog.

    Args:
        image: A PIL image, a NumPy array accepted by ``Image.fromarray``,
            or ``None`` when no image was supplied.

    Returns:
        A ``(scores, plot_data)`` tuple:

        * ``scores`` maps ``"Cat"`` and ``"Dog"`` to confidences rounded to
          four decimals. The first value of the model output is used as the
          dog confidence and the cat confidence is its complement.
        * ``plot_data`` is a DataFrame with ``Class`` and ``Confidence``
          columns holding the same (unrounded) values.

        With ``image=None`` both classes get 0.5. When no model is loaded a
        fixed demo confidence of 0.75 is used. If anything raises, the error
        is logged and ``({"Error": <message>}, <empty DataFrame>)`` is
        returned instead of propagating the exception.
    """
    try:
        if image is None:
            return _build_outputs(0.5)

        # Convert to PIL Image if numpy array
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Preprocess. Force three RGB channels first so grayscale, palette and
        # RGBA inputs all produce a (150, 150, 3) array; previously a 2-D
        # grayscale array skipped the batch dimension entirely and RGBA input
        # reached the model with four channels.
        # NOTE(review): assumes the model expects 3-channel RGB input, which
        # is what the gr.Image(type="pil") input is expected to deliver —
        # confirm against the model's input shape.
        image = image.convert("RGB").resize((150, 150))
        img_array = np.array(image) / 255.0
        img_array = np.expand_dims(img_array, axis=0)

        # Predict
        if model is not None:
            pred = model.predict(img_array, verbose=0)
            confidence = float(pred[0][0])
        else:
            confidence = 0.75  # Demo value

        return _build_outputs(confidence)

    except Exception as e:
        # Keep the original error contract for the UI, but log the traceback.
        logger.exception("Error: %s", e)
        return {"Error": str(e)}, pd.DataFrame()

# Interface
# Chart contents shown before the first classification: an even split.
_initial_scores = pd.DataFrame({'Class': ['Cat', 'Dog'], 'Confidence': [0.5, 0.5]})

# Sample images offered under the classifier; one single-input row each.
_example_rows = [
    ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
    ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"],
]

with gr.Blocks() as demo:
    gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")

    with gr.Row():
        # Left column: image input and the button that triggers classification.
        with gr.Column():
            img_input = gr.Image(type="pil")
            classify_btn = gr.Button("Classify", variant="primary")

        # Right column: class label with confidences, plus a bar chart.
        with gr.Column():
            label_out = gr.Label(num_top_classes=2)
            plot_out = gr.BarPlot(
                _initial_scores,
                x="Class",
                y="Confidence",
                y_lim=[0, 1],
            )

    # Clicking the button runs classify_image on the current image and sends
    # its two return values to the label and the chart (no api_name is set).
    classify_btn.click(
        fn=classify_image,
        inputs=img_input,
        outputs=[label_out, plot_out],
    )

    # Examples section: results are cached, using classify_image as the
    # function that produces them.
    gr.Examples(
        examples=_example_rows,
        fn=classify_image,
        inputs=img_input,
        outputs=[label_out, plot_out],
        cache_examples=True,
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()