File size: 3,563 Bytes
4bafa27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd90460
 
4bafa27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a1a754
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import pandas as pd
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Force CPU-only execution (intended for Hugging Face Spaces). Note that this
# hides GPUs unconditionally, even when one is available.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# `tensorflow` is imported before the line above runs, so the env var alone is
# fragile. Also hide GPUs through TensorFlow's own API; it raises RuntimeError
# if TF has already initialised its devices, in which case GPUs may stay visible.
try:
    tf.config.set_visible_devices([], 'GPU')
except RuntimeError as err:
    logger.warning("Could not hide GPUs from TensorFlow: %s", err)

# Configuration: Hugging Face Hub repository and file holding the Keras model.
MODEL_REPO = "Ahmedhassan54/Image-Classification-Model"
MODEL_FILE = "best_model.keras"

# Loaded Keras model; stays None when loading fails (the app then runs in demo mode).
model = None

def load_model(*, force_download=False, cache_dir="."):
    """Download the classifier from the Hugging Face Hub and load it into ``model``.

    On success the module-level ``model`` holds the loaded Keras model. On any
    failure it is reset to ``None``, so ``classify_image`` falls back to its
    demo value instead of the app failing at startup.

    Args:
        force_download: Re-fetch the file even if a cached copy exists. Defaults
            to ``False`` so restarts reuse the hub cache rather than always
            re-downloading.
        cache_dir: Directory ``hf_hub_download`` uses for its cache.

    Returns:
        The loaded model, or ``None`` if downloading or loading failed.
    """
    global model
    try:
        logger.info("Downloading model...")
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            cache_dir=cache_dir,
            force_download=force_download,
        )
        logger.info("Model path: %s", model_path)

        # Load under a CPU device scope so the model is not placed on a GPU.
        with tf.device('/CPU:0'):
            model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully!")
    except Exception:
        # Deliberately best-effort: keep the app running in demo mode, but log
        # the full traceback so the failure can be diagnosed.
        logger.exception("Model loading failed")
        model = None
    return model


# Load model at startup
load_model()

def _format_outputs(dog_confidence):
    """Build the (label dict, bar-plot DataFrame) pair for a Dog probability.

    Cat receives the complement ``1 - dog_confidence``. Label values are rounded
    to 4 decimals; the plot values are left unrounded.
    """
    cat_confidence = 1 - dog_confidence
    results = {
        "Cat": round(cat_confidence, 4),
        "Dog": round(dog_confidence, 4),
    }
    plot_data = pd.DataFrame({
        'Class': ['Cat', 'Dog'],
        'Confidence': [cat_confidence, dog_confidence],
    })
    return results, plot_data


def classify_image(image):
    """Classify an image as Cat or Dog.

    Args:
        image: A PIL image, a numpy array (converted to uint8 PIL), or ``None``.

    Returns:
        A ``(label, plot_data)`` pair. ``label`` is a ``{"Cat": p, "Dog": p}``
        dict on success, or an ``"Error: ..."`` string on failure. ``plot_data``
        is a two-row DataFrame with ``Class`` and ``Confidence`` columns. With no
        image, a neutral 50/50 split is returned. If the model failed to load, a
        fixed demo Dog confidence of 0.75 is used.
    """
    if image is None:
        # Nothing uploaded yet: show a neutral 50/50 split.
        return _format_outputs(0.5)

    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))

        # Normalise to 3-channel RGB before resizing. RGBA/palette images would
        # otherwise give a (150, 150, 4) array, and grayscale a 2-D array with
        # no batch axis; the model rejects both.
        # NOTE(review): assumes the model expects 150x150 RGB scaled to [0, 1],
        # matching the original preprocessing; confirm against training code.
        image = image.convert("RGB").resize((150, 150))
        img_array = np.asarray(image, dtype=np.float32) / 255.0
        img_array = np.expand_dims(img_array, axis=0)  # add batch axis

        if model is not None:
            with tf.device('/CPU:0'):
                pred = model.predict(img_array, verbose=0)
            # The first output value is read as P(Dog); Cat gets the complement.
            confidence = float(pred[0][0])
        else:
            confidence = 0.75  # Demo value

        return _format_outputs(confidence)

    except Exception as e:
        logger.exception("Error: %s", e)
        # gr.Label accepts a plain string; a {"Error": "..."} dict would fail
        # because Label expects numeric confidences.
        return f"Error: {e}", _format_outputs(0.5)[1]

# Gradio UI: image upload and button on the left, label and bar chart on the
# right, followed by clickable example images. Component creation order defines
# the on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")

    with gr.Row():
        # Left column: the input image and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil")
            classify_button = gr.Button("Classify", variant="primary")

        # Right column: per-class confidences as a label and as a bar chart.
        with gr.Column():
            label_output = gr.Label(num_top_classes=2)
            initial_plot = pd.DataFrame({'Class': ['Cat', 'Dog'], 'Confidence': [0.5, 0.5]})
            plot_output = gr.BarPlot(
                initial_plot,
                x="Class",
                y="Confidence",
                y_lim=[0, 1],
            )

    classify_button.click(
        fn=classify_image,
        inputs=image_input,
        outputs=[label_output, plot_output],
    )

    # Example images; cache_examples=True precomputes their outputs.
    example_urls = [
        "https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg",
        "https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg",
    ]
    gr.Examples(
        examples=[[url] for url in example_urls],
        inputs=image_input,
        outputs=[label_output, plot_output],
        fn=classify_image,
        cache_examples=True,
    )

if __name__ == "__main__":
    # Serve on every network interface, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")