Moditha24 commited on
Commit
c7f845c
Β·
verified Β·
1 Parent(s): 83b90f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -13
app.py CHANGED
@@ -4,6 +4,8 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  import torchvision.transforms as transforms
6
  from PIL import Image
 
 
7
  from ResNet_for_CC import CC_model # Import the model
8
 
9
  # Set device (CPU/GPU)
@@ -26,7 +28,15 @@ class_labels = [
26
  "Vest", "Underwear"
27
  ]
28
 
29
- # βœ… **Updated Image Preprocessing Function**
 
 
 
 
 
 
 
 
30
  def preprocess_image(image):
31
  """Applies necessary transformations to the input image."""
32
  transform = transforms.Compose([
@@ -38,18 +48,27 @@ def preprocess_image(image):
38
  return transform(image).unsqueeze(0).to(device)
39
 
40
  # βœ… **Classification Function**
41
- def classify_image(image):
42
- """Processes the input image and returns the predicted clothing category."""
43
- print("\n[INFO] Received image for classification.")
44
 
 
 
45
  try:
46
- image = Image.fromarray(image) # Ensure conversion to PIL format
 
 
 
 
 
 
 
 
47
  image = preprocess_image(image) # Apply transformations
48
  print("[INFO] Image transformed and moved to device.")
49
 
50
  with torch.no_grad():
51
  output = model(image)
52
-
53
  # βœ… Ensure output is a tensor (handle tuple case)
54
  if isinstance(output, tuple):
55
  output = output[1] # Extract the actual output tensor
@@ -81,13 +100,32 @@ def classify_image(image):
81
  return "Error in classification. Check console for details."
82
 
83
  # βœ… **Gradio Interface**
84
- interface = gr.Interface(
85
- fn=classify_image,
86
- inputs=gr.Image(type="numpy"),
87
- outputs="text",
88
- title="Clothing1M Image Classifier",
89
- description="Upload a clothing image, and the model will classify it into one of the 14 categories."
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  # βœ… **Run the Interface**
93
  if __name__ == "__main__":
 
4
  import torch.nn.functional as F
5
  import torchvision.transforms as transforms
6
  from PIL import Image
7
+ import numpy as np
8
+ import os
9
  from ResNet_for_CC import CC_model # Import the model
10
 
11
  # Set device (CPU/GPU)
 
28
  "Vest", "Underwear"
29
  ]
30
 
31
+ # βœ… **Predefined Default Images**
32
+ default_images = {
33
+ "T-Shirt": "default_images/tshirt.jpg",
34
+ "Jacket": "default_images/jacket.jpg",
35
+ "Sweater": "default_images/sweater.jpg",
36
+ "Dress": "default_images/dress.jpg"
37
+ }
38
+
39
+ # βœ… **Image Preprocessing Function**
40
  def preprocess_image(image):
41
  """Applies necessary transformations to the input image."""
42
  transform = transforms.Compose([
 
48
  return transform(image).unsqueeze(0).to(device)
49
 
50
  # βœ… **Classification Function**
51
+ def classify_image(selected_default, uploaded_image):
52
+ """Processes either a default or uploaded image and returns the predicted clothing category."""
 
53
 
54
+ print("\n[INFO] Image selection process started.")
55
+
56
  try:
57
+ # Use the uploaded image if provided; otherwise, use the selected default image
58
+ if uploaded_image is not None:
59
+ print("[INFO] Using uploaded image.")
60
+ image = Image.fromarray(uploaded_image) # Ensure conversion to PIL format
61
+ else:
62
+ print(f"[INFO] Using default image: {selected_default}")
63
+ image_path = default_images[selected_default]
64
+ image = Image.open(image_path) # Load the selected default image
65
+
66
  image = preprocess_image(image) # Apply transformations
67
  print("[INFO] Image transformed and moved to device.")
68
 
69
  with torch.no_grad():
70
  output = model(image)
71
+
72
  # βœ… Ensure output is a tensor (handle tuple case)
73
  if isinstance(output, tuple):
74
  output = output[1] # Extract the actual output tensor
 
100
  return "Error in classification. Check console for details."
101
 
102
  # βœ… **Gradio Interface**
103
+ with gr.Blocks() as interface:
104
+ gr.Markdown("# Clothing1M Image Classifier")
105
+ gr.Markdown("Upload a clothing image or select from the predefined images below.")
106
+
107
+ # Default Image Selection
108
+ default_selector = gr.Radio(
109
+ choices=list(default_images.keys()),
110
+ label="Select a Default Image",
111
+ value="T-Shirt"
112
+ )
113
+
114
+ # File Upload Option
115
+ image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")
116
+
117
+ # Output Text
118
+ output_text = gr.Textbox(label="Classification Result")
119
+
120
+ # Classify Button
121
+ classify_button = gr.Button("Classify Image")
122
+
123
+ # Define Action
124
+ classify_button.click(
125
+ fn=classify_image,
126
+ inputs=[default_selector, image_upload],
127
+ outputs=output_text
128
+ )
129
 
130
  # βœ… **Run the Interface**
131
  if __name__ == "__main__":