Moditha24 commited on
Commit
57b4f24
·
verified ·
1 Parent(s): 491fff3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -43
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import gradio as gr
2
  import torch
3
- import torch.nn as nn
4
  import torch.nn.functional as F
5
  import torchvision.transforms as transforms
6
  from PIL import Image
7
- import numpy as np
8
  import os
9
  from ResNet_for_CC import CC_model # Import the model
10
 
@@ -36,8 +34,8 @@ default_images = {
36
  "Vest": "dress.jpg"
37
  }
38
 
39
- # Convert image paths to a format suitable for Gradio
40
- default_image_paths = [[img_path] for img_path in default_images.values()]
41
 
42
  # **Image Preprocessing Function**
43
  def preprocess_image(image):
@@ -53,85 +51,63 @@ def preprocess_image(image):
53
  # **Classification Function**
54
  def classify_image(selected_default, uploaded_image):
55
  """Processes either a default or uploaded image and returns the predicted clothing category."""
56
-
57
- print("\n[INFO] Image selection process started.")
58
-
59
  try:
60
  # Use the uploaded image if provided; otherwise, use the selected default image
61
  if uploaded_image is not None:
62
- print("[INFO] Using uploaded image.")
63
- image = Image.fromarray(uploaded_image) # Ensure conversion to PIL format
64
  else:
65
- print(f"[INFO] Using default image: {selected_default}")
66
  image_path = default_images[selected_default]
67
- image = Image.open(image_path) # Load the selected default image
68
-
69
- image = preprocess_image(image) # Apply transformations
70
- print("[INFO] Image transformed and moved to device.")
71
 
 
 
72
  with torch.no_grad():
73
  output = model(image)
74
-
75
- # Ensure output is a tensor (handle tuple case)
76
  if isinstance(output, tuple):
77
- output = output[1] # Extract the actual output tensor
78
-
79
- print(f"[DEBUG] Model output shape: {output.shape}")
80
- print(f"[DEBUG] Model output values: {output}")
81
 
82
- if output.shape[1] != 14:
83
- return f"[ERROR] Model output mismatch! Expected 14 but got {output.shape[1]}."
84
-
85
- # Convert logits to probabilities
86
  probabilities = F.softmax(output, dim=1)
87
- print(f"[DEBUG] Softmax probabilities: {probabilities}")
88
-
89
- # Get predicted class index
90
  predicted_class = torch.argmax(probabilities, dim=1).item()
91
- print(f"[INFO] Predicted class index: {predicted_class} (Class: {class_labels[predicted_class]})")
92
 
93
- # Validate and return the prediction
94
  if 0 <= predicted_class < len(class_labels):
95
  predicted_label = class_labels[predicted_class]
96
  confidence = probabilities[0][predicted_class].item() * 100
97
  return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
98
  else:
99
  return "[ERROR] Model returned an invalid class index."
100
-
101
  except Exception as e:
102
- print(f"[ERROR] Exception during classification: {e}")
103
- return "Error in classification. Check console for details."
104
 
105
  # **Gradio Interface**
106
  with gr.Blocks() as interface:
107
  gr.Markdown("# Clothing1M Image Classifier")
108
  gr.Markdown("Upload a clothing image or select from the predefined images below.")
109
 
110
- # **Default Image Gallery for Selection**
111
- gr.Markdown("### Select a Default Image:")
112
  gallery = gr.Gallery(
113
- value=default_image_paths,
114
- label="Available Default Images",
115
- show_label=True
116
  )
117
 
118
- # **Default Image Selection Radio**
119
- default_selector = gr.Radio(
120
  choices=list(default_images.keys()),
121
  label="Select a Default Image",
122
  value="Shawl"
123
  )
124
 
125
- # **File Upload Option**
126
  image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")
127
 
128
- # **Output Text**
129
  output_text = gr.Textbox(label="Classification Result")
130
 
131
- # **Classify Button**
132
  classify_button = gr.Button("Classify Image")
133
 
134
- # **Define Action**
135
  classify_button.click(
136
  fn=classify_image,
137
  inputs=[default_selector, image_upload],
 
1
  import gradio as gr
2
  import torch
 
3
  import torch.nn.functional as F
4
  import torchvision.transforms as transforms
5
  from PIL import Image
 
6
  import os
7
  from ResNet_for_CC import CC_model # Import the model
8
 
 
34
  "Vest": "dress.jpg"
35
  }
36
 
37
+ # Convert to gallery format (list of (image_path, caption) tuples)
38
+ default_images_gallery = [(path, label) for label, path in default_images.items()]
39
 
40
  # **Image Preprocessing Function**
41
  def preprocess_image(image):
 
51
  # **Classification Function**
52
  def classify_image(selected_default, uploaded_image):
53
  """Processes either a default or uploaded image and returns the predicted clothing category."""
 
 
 
54
  try:
55
  # Use the uploaded image if provided; otherwise, use the selected default image
56
  if uploaded_image is not None:
57
+ image = Image.fromarray(uploaded_image)
 
58
  else:
 
59
  image_path = default_images[selected_default]
60
+ image = Image.open(image_path)
 
 
 
61
 
62
+ image = preprocess_image(image)
63
+
64
  with torch.no_grad():
65
  output = model(image)
 
 
66
  if isinstance(output, tuple):
67
+ output = output[1]
 
 
 
68
 
 
 
 
 
69
  probabilities = F.softmax(output, dim=1)
 
 
 
70
  predicted_class = torch.argmax(probabilities, dim=1).item()
 
71
 
 
72
  if 0 <= predicted_class < len(class_labels):
73
  predicted_label = class_labels[predicted_class]
74
  confidence = probabilities[0][predicted_class].item() * 100
75
  return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
76
  else:
77
  return "[ERROR] Model returned an invalid class index."
78
+
79
  except Exception as e:
80
+ return f"Error in classification: {e}"
 
81
 
82
  # **Gradio Interface**
83
  with gr.Blocks() as interface:
84
  gr.Markdown("# Clothing1M Image Classifier")
85
  gr.Markdown("Upload a clothing image or select from the predefined images below.")
86
 
87
+ # Gallery to display default images
 
88
  gallery = gr.Gallery(
89
+ value=default_images_gallery, # Provide list of (image, caption) tuples
90
+ label="Default Images",
91
+ elem_id="default_gallery"
92
  )
93
 
94
+ # Default Image Selection
95
+ default_selector = gr.Dropdown(
96
  choices=list(default_images.keys()),
97
  label="Select a Default Image",
98
  value="Shawl"
99
  )
100
 
101
+ # File Upload Option
102
  image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")
103
 
104
+ # Output Text
105
  output_text = gr.Textbox(label="Classification Result")
106
 
107
+ # Classify Button
108
  classify_button = gr.Button("Classify Image")
109
 
110
+ # Define Action
111
  classify_button.click(
112
  fn=classify_image,
113
  inputs=[default_selector, image_upload],