Ahmedhassan54 commited on
Commit
c96e0d1
·
verified ·
1 Parent(s): a108e44

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -96
app.py CHANGED
@@ -1,96 +1,118 @@
1
- import gradio as gr
2
- import tensorflow as tf
3
- import numpy as np
4
- from PIL import Image
5
- from huggingface_hub import hf_hub_download
6
- import os
7
-
8
- # Configuration
9
- MODEL_REPO = "Ahmedhassan54/Image-Classification" # Replace with your HF username and repo
10
- MODEL_FILE = "best_model.h5"
11
-
12
- # Download model from Hugging Face Hub
13
- def load_model_from_hf():
14
- try:
15
- if not os.path.exists(MODEL_FILE):
16
- print("Downloading model from Hugging Face Hub...")
17
- model_path = hf_hub_download(
18
- repo_id=MODEL_REPO,
19
- filename=MODEL_FILE,
20
- cache_dir="."
21
- )
22
- os.system(f"cp {model_path} {MODEL_FILE}")
23
-
24
- return tf.keras.models.load_model(MODEL_FILE)
25
- except Exception as e:
26
- raise gr.Error(f"Model loading failed: {str(e)}")
27
-
28
- model = load_model_from_hf()
29
-
30
- def classify_image(image):
31
- try:
32
- image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
33
- image = image.resize((150, 150))
34
- image_array = np.array(image) / 255.0
35
- image_array = np.expand_dims(image_array, axis=0)
36
-
37
- prediction = model.predict(image_array)
38
- confidence = float(prediction[0][0])
39
-
40
- return {
41
- "Dog": confidence,
42
- "Cat": 1 - confidence
43
- }
44
- except Exception as e:
45
- raise gr.Error(f"Classification error: {str(e)}")
46
-
47
- # Custom CSS for better UI
48
- css = """
49
- .gradio-container {
50
- background: linear-gradient(to right, #f5f7fa, #c3cfe2);
51
- }
52
- footer {
53
- visibility: hidden
54
- }
55
- """
56
-
57
- # Build the interface
58
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
59
- gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")
60
- gr.Markdown("Upload an image to classify whether it's a cat or dog")
61
-
62
- with gr.Row():
63
- with gr.Column():
64
- image_input = gr.Image(label="Upload Image", type="pil")
65
- submit_btn = gr.Button("Classify", variant="primary")
66
-
67
- with gr.Column():
68
- label_output = gr.Label(label="Predictions", num_top_classes=2)
69
- confidence_bar = gr.BarPlot(
70
- x=["Cat", "Dog"],
71
- y=[0.5, 0.5],
72
- y_lim=[0,1],
73
- title="Confidence Scores",
74
- width=400,
75
- height=300
76
- )
77
-
78
- # Example images
79
- gr.Examples(
80
- examples=[
81
- ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
82
- ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"]
83
- ],
84
- inputs=image_input
85
- )
86
-
87
- # Button action
88
- submit_btn.click(
89
- fn=classify_image,
90
- inputs=image_input,
91
- outputs=[label_output, confidence_bar],
92
- api_name="classify"
93
- )
94
-
95
- if __name__ == "__main__":
96
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ from huggingface_hub import hf_hub_download
6
+ import os
7
+ import pandas as pd
8
+
9
+ # Configuration
10
+ MODEL_REPO = "Ahmedhassan54/Image-Classification" # Changed to your actual repo
11
+ MODEL_FILE = "best_model.h5"
12
+
13
+ # Download model from Hugging Face Hub
14
+ def load_model_from_hf():
15
+ try:
16
+ if not os.path.exists(MODEL_FILE):
17
+ print("Downloading model from Hugging Face Hub...")
18
+ model_path = hf_hub_download(
19
+ repo_id=MODEL_REPO,
20
+ filename=MODEL_FILE,
21
+ cache_dir="."
22
+ )
23
+ os.system(f"cp {model_path} {MODEL_FILE}")
24
+
25
+ return tf.keras.models.load_model(MODEL_FILE)
26
+ except Exception as e:
27
+ raise gr.Error(f"Model loading failed: {str(e)}")
28
+
29
+ model = load_model_from_hf()
30
+
31
+ def classify_image(image):
32
+ try:
33
+ # Convert image if needed
34
+ if isinstance(image, np.ndarray):
35
+ image = Image.fromarray(image)
36
+
37
+ # Preprocess image
38
+ image = image.resize((150, 150))
39
+ image_array = np.array(image) / 255.0
40
+ image_array = np.expand_dims(image_array, axis=0)
41
+
42
+ # Make prediction
43
+ prediction = model.predict(image_array, verbose=0)
44
+ confidence = float(prediction[0][0])
45
+
46
+ # Format outputs
47
+ label_output = {
48
+ "Cat": 1 - confidence,
49
+ "Dog": confidence
50
+ }
51
+
52
+ # Create dataframe for bar plot
53
+ plot_data = pd.DataFrame({
54
+ 'Class': ['Cat', 'Dog'],
55
+ 'Confidence': [1 - confidence, confidence]
56
+ })
57
+
58
+ return label_output, plot_data
59
+
60
+ except Exception as e:
61
+ print(f"Error: {str(e)}") # Debug print
62
+ raise gr.Error(f"Classification error: {str(e)}")
63
+
64
+ # Custom CSS
65
+ css = """
66
+ .gradio-container {
67
+ background: linear-gradient(to right, #f5f7fa, #c3cfe2);
68
+ }
69
+ footer {
70
+ visibility: hidden
71
+ }
72
+ """
73
+
74
+ # Build the interface
75
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
76
+ gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")
77
+ gr.Markdown("Upload an image to classify whether it's a cat or dog")
78
+
79
+ with gr.Row():
80
+ with gr.Column():
81
+ image_input = gr.Image(label="Upload Image", type="pil")
82
+ submit_btn = gr.Button("Classify", variant="primary")
83
+
84
+ with gr.Column():
85
+ label_output = gr.Label(label="Predictions", num_top_classes=2)
86
+ confidence_bar = gr.BarPlot(
87
+ pd.DataFrame({'Class': ['Cat', 'Dog'], 'Confidence': [0.5, 0.5]}),
88
+ x="Class",
89
+ y="Confidence",
90
+ y_lim=[0,1],
91
+ title="Confidence Scores",
92
+ width=400,
93
+ height=300,
94
+ container=False
95
+ )
96
+
97
+ # Example images
98
+ gr.Examples(
99
+ examples=[
100
+ ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
101
+ ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"]
102
+ ],
103
+ inputs=image_input,
104
+ outputs=[label_output, confidence_bar],
105
+ fn=classify_image,
106
+ cache_examples=True
107
+ )
108
+
109
+ # Button action
110
+ submit_btn.click(
111
+ fn=classify_image,
112
+ inputs=image_input,
113
+ outputs=[label_output, confidence_bar],
114
+ api_name="classify"
115
+ )
116
+
117
+ if __name__ == "__main__":
118
+ demo.launch(debug=True)