Ahmedhassan54 commited on
Commit
09fcd39
·
verified ·
1 Parent(s): ae329be

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -44
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
@@ -6,14 +5,13 @@ from PIL import Image
6
  from huggingface_hub import hf_hub_download
7
  import os
8
 
9
-
10
- MODEL_REPO = "Ahmedhassan54/Image-Classification"
11
  MODEL_FILE = "best_model.h5"
12
 
13
-
14
  def load_model_from_hf():
15
  try:
16
-
17
  if not os.path.exists(MODEL_FILE):
18
  print("Downloading model from Hugging Face Hub...")
19
  model_path = hf_hub_download(
@@ -21,59 +19,78 @@ def load_model_from_hf():
21
  filename=MODEL_FILE,
22
  cache_dir="."
23
  )
24
-
25
  os.system(f"cp {model_path} {MODEL_FILE}")
26
 
27
-
28
- model = tf.keras.models.load_model(MODEL_FILE)
29
- print("Model loaded successfully!")
30
- return model
31
  except Exception as e:
32
- print(f"Error loading model: {str(e)}")
33
- raise
34
-
35
 
36
  model = load_model_from_hf()
37
 
38
-
39
  def classify_image(image):
40
  try:
41
-
42
- image = image.resize((150, 150))
43
- image_array = np.array(image) / 255.0
44
- image_array = np.expand_dims(image_array, axis=0)
45
 
46
-
47
  prediction = model.predict(image_array)
48
  confidence = float(prediction[0][0])
49
 
50
- if confidence > 0.5:
51
- return {
52
- "Dog": confidence * 100,
53
- "Cat": (1 - confidence) * 100
54
- }
55
- else:
56
- return {
57
- "Cat": (1 - confidence) * 100,
58
- "Dog": confidence * 100
59
- }
60
  except Exception as e:
61
- return f"Error processing image: {str(e)}"
62
-
63
 
64
- demo = gr.Interface(
65
- fn=classify_image,
66
- inputs=gr.Image(type="pil", label="Upload Image"),
67
- outputs=gr.Label(num_top_classes=2, label="Predictions"),
68
- title="🐱 Cat vs Dog Classifier 🐶",
69
- description="Upload an image to classify whether it's a cat or dog",
70
- examples=[
71
- ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
72
- ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"]
73
- ],
74
- allow_flagging="never"
75
- )
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  if __name__ == "__main__":
79
- demo.launch(debug=True, server_port=7860)
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
 
5
  from huggingface_hub import hf_hub_download
6
  import os
7
 
8
+ # Configuration
9
+ MODEL_REPO = "your_hf_username/cat-dog-classifier" # Replace with your HF username and repo
10
  MODEL_FILE = "best_model.h5"
11
 
12
+ # Download model from Hugging Face Hub
13
  def load_model_from_hf():
14
  try:
 
15
  if not os.path.exists(MODEL_FILE):
16
  print("Downloading model from Hugging Face Hub...")
17
  model_path = hf_hub_download(
 
19
  filename=MODEL_FILE,
20
  cache_dir="."
21
  )
 
22
  os.system(f"cp {model_path} {MODEL_FILE}")
23
 
24
+ return tf.keras.models.load_model(MODEL_FILE)
 
 
 
25
  except Exception as e:
26
+ raise gr.Error(f"Model loading failed: {str(e)}")
 
 
27
 
28
  model = load_model_from_hf()
29
 
 
30
  def classify_image(image):
31
  try:
32
+ image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
33
+ image = image.resize((150, 150))
34
+ image_array = np.array(image) / 255.0
35
+ image_array = np.expand_dims(image_array, axis=0)
36
 
 
37
  prediction = model.predict(image_array)
38
  confidence = float(prediction[0][0])
39
 
40
+ return {
41
+ "Dog": confidence,
42
+ "Cat": 1 - confidence
43
+ }
 
 
 
 
 
 
44
  except Exception as e:
45
+ raise gr.Error(f"Classification error: {str(e)}")
 
46
 
47
+ # Custom CSS for better UI
48
+ css = """
49
+ .gradio-container {
50
+ background: linear-gradient(to right, #f5f7fa, #c3cfe2);
51
+ }
52
+ footer {
53
+ visibility: hidden
54
+ }
55
+ """
 
 
 
56
 
57
+ # Build the interface
58
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
59
+ gr.Markdown("# 🐾 Cat vs Dog Classifier 🦮")
60
+ gr.Markdown("Upload an image to classify whether it's a cat or dog")
61
+
62
+ with gr.Row():
63
+ with gr.Column():
64
+ image_input = gr.Image(label="Upload Image", type="pil")
65
+ submit_btn = gr.Button("Classify", variant="primary")
66
+
67
+ with gr.Column():
68
+ label_output = gr.Label(label="Predictions", num_top_classes=2)
69
+ confidence_bar = gr.BarPlot(
70
+ x=["Cat", "Dog"],
71
+ y=[0.5, 0.5],
72
+ y_lim=[0,1],
73
+ title="Confidence Scores",
74
+ width=400,
75
+ height=300
76
+ )
77
+
78
+ # Example images
79
+ gr.Examples(
80
+ examples=[
81
+ ["https://upload.wikimedia.org/wikipedia/commons/1/15/Cat_August_2010-4.jpg"],
82
+ ["https://upload.wikimedia.org/wikipedia/commons/d/d9/Collage_of_Nine_Dogs.jpg"]
83
+ ],
84
+ inputs=image_input
85
+ )
86
+
87
+ # Button action
88
+ submit_btn.click(
89
+ fn=classify_image,
90
+ inputs=image_input,
91
+ outputs=[label_output, confidence_bar],
92
+ api_name="classify"
93
+ )
94
 
95
  if __name__ == "__main__":
96
+ demo.launch()