File size: 3,301 Bytes
4bafa27
 
 
 
 
 
 
 
 
217a3bc
4bafa27
 
217a3bc
4bafa27
 
 
217a3bc
cd90460
 
4bafa27
217a3bc
4bafa27
 
 
 
 
 
 
 
 
 
 
 
 
 
217a3bc
4bafa27
 
 
 
 
 
 
217a3bc
4bafa27
 
 
 
 
 
 
217a3bc
4bafa27
 
 
217a3bc
4bafa27
 
 
 
 
217a3bc
4bafa27
 
 
 
 
217a3bc
4bafa27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217a3bc
4bafa27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217a3bc
4bafa27
 
 
 
 
 
 
 
 
 
 
 
3a1a754
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import pandas as pd
import logging


# Hide every CUDA device so TensorFlow falls back to the CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="-1")

# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub coordinates of the trained classifier.
MODEL_REPO = "Ahmedhassan54/Image-Classification-Model"
MODEL_FILE = "best_model.keras"

# Filled in by load_model(); stays None while no model is available.
model = None

def load_model():
    """Download the Keras model from the Hugging Face Hub and load it on the CPU.

    On success the module-level ``model`` holds the loaded model. On any
    failure the error is logged together with its traceback and ``model`` is
    set to ``None``; nothing is raised, so the caller must check ``model``
    before using it.
    """
    global model
    try:
        logger.info("Downloading model...")
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            cache_dir=".",
            # Re-fetch the file even when a cached copy already exists.
            force_download=True,
        )
        logger.info("Model path: %s", model_path)

        # Pin loading to the CPU device.
        with tf.device('/CPU:0'):
            model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully!")
    except Exception as exc:
        # Best-effort boundary: keep the process alive without a model, but
        # record the full traceback so the failure can be diagnosed.
        logger.exception("Model loading failed: %s", exc)
        model = None


# Load the model once at import time, before the UI below is constructed.
load_model()

def _confidence_frame(cat, dog):
    """Build the two-row Class/Confidence table consumed by the bar plot."""
    return pd.DataFrame({'Class': ['Cat', 'Dog'], 'Confidence': [cat, dog]})


def classify_image(image):
    """Classify an image as cat or dog for the label and bar-plot outputs.

    Args:
        image: A PIL image or a NumPy array, or ``None`` when no image was
            supplied.

    Returns:
        A ``(label, plot_data)`` tuple. ``label`` is a dict mapping "Cat" and
        "Dog" to confidences rounded to four decimals; if processing fails it
        is instead a plain string carrying the error message. ``plot_data`` is
        a DataFrame with "Class" and "Confidence" columns (an even 0.5/0.5
        split when there is no image or an error occurred).
    """
    try:
        if image is None:
            return {"Cat": 0.5, "Dog": 0.5}, _confidence_frame(0.5, 0.5)

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))

        # Normalise to three channels: grayscale or RGBA uploads would
        # otherwise reach the model with the wrong shape.
        # NOTE(review): assumes the model expects 150x150 RGB input scaled to
        # [0, 1] -- confirm against the training pipeline.
        image = image.convert("RGB").resize((150, 150))
        img_array = np.expand_dims(np.array(image) / 255.0, axis=0)

        if model is not None:
            with tf.device('/CPU:0'):
                pred = model.predict(img_array, verbose=0)
            # NOTE(review): treats the first output value as the "Dog"
            # probability -- confirm against the model's output layer.
            confidence = float(pred[0][0])
        else:
            # NOTE(review): placeholder value used when no model is loaded;
            # it is not a real prediction.
            logger.warning("Model unavailable; returning placeholder confidence")
            confidence = 0.75

        results = {
            "Cat": round(1 - confidence, 4),
            "Dog": round(confidence, 4),
        }
        return results, _confidence_frame(1 - confidence, confidence)

    except Exception as exc:
        logger.exception("Error: %s", exc)
        # A label dict must map to numeric confidences, so report the failure
        # as a plain string label instead of {"Error": <message>}.
        return f"Error: {exc}", _confidence_frame(0.5, 0.5)


# Page layout: one row with two columns -- image input and button in the
# first, label and bar-plot outputs in the second.
with gr.Blocks() as demo:
    gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil")
            classify_btn = gr.Button("Classify", variant="primary")

        with gr.Column():
            label_out = gr.Label(num_top_classes=2)
            # The plot starts with an even 0.5/0.5 split.
            plot_out = gr.BarPlot(
                pd.DataFrame({'Class': ['Cat', 'Dog'], 'Confidence': [0.5, 0.5]}),
                x="Class",
                y="Confidence",
                y_lim=[0, 1],
            )

    # Clicking the button runs the classifier and fills both outputs.
    classify_btn.click(
        fn=classify_image,
        inputs=img_input,
        outputs=[label_out, plot_out],
    )

    # Sample images wired to the same function, with example caching enabled.
    gr.Examples(
        examples=[
            ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
            ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"],
        ],
        fn=classify_image,
        inputs=img_input,
        outputs=[label_out, plot_out],
        cache_examples=True,
    )

# Serve on all interfaces, port 7860, when executed as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)