Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,8 @@ import torch.nn as nn
|
|
4 |
import torch.nn.functional as F
|
5 |
import torchvision.transforms as transforms
|
6 |
from PIL import Image
|
|
|
|
|
7 |
from ResNet_for_CC import CC_model # Import the model
|
8 |
|
9 |
# Set device (CPU/GPU)
|
@@ -26,7 +28,15 @@ class_labels = [
|
|
26 |
"Vest", "Underwear"
|
27 |
]
|
28 |
|
29 |
-
# β
**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def preprocess_image(image):
|
31 |
"""Applies necessary transformations to the input image."""
|
32 |
transform = transforms.Compose([
|
@@ -38,18 +48,27 @@ def preprocess_image(image):
|
|
38 |
return transform(image).unsqueeze(0).to(device)
|
39 |
|
40 |
# β
**Classification Function**
|
41 |
-
def classify_image(
|
42 |
-
"""Processes
|
43 |
-
print("\n[INFO] Received image for classification.")
|
44 |
|
|
|
|
|
45 |
try:
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
image = preprocess_image(image) # Apply transformations
|
48 |
print("[INFO] Image transformed and moved to device.")
|
49 |
|
50 |
with torch.no_grad():
|
51 |
output = model(image)
|
52 |
-
|
53 |
# β
Ensure output is a tensor (handle tuple case)
|
54 |
if isinstance(output, tuple):
|
55 |
output = output[1] # Extract the actual output tensor
|
@@ -81,13 +100,32 @@ def classify_image(image):
|
|
81 |
return "Error in classification. Check console for details."
|
82 |
|
83 |
# β
**Gradio Interface**
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
# β
**Run the Interface**
|
93 |
if __name__ == "__main__":
|
|
|
4 |
import torch.nn.functional as F
|
5 |
import torchvision.transforms as transforms
|
6 |
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
from ResNet_for_CC import CC_model # Import the model
|
10 |
|
11 |
# Set device (CPU/GPU)
|
|
|
28 |
"Vest", "Underwear"
|
29 |
]
|
30 |
|
31 |
+
# β
**Predefined Default Images**
|
32 |
+
default_images = {
|
33 |
+
"T-Shirt": "default_images/tshirt.jpg",
|
34 |
+
"Jacket": "default_images/jacket.jpg",
|
35 |
+
"Sweater": "default_images/sweater.jpg",
|
36 |
+
"Dress": "default_images/dress.jpg"
|
37 |
+
}
|
38 |
+
|
39 |
+
# β
**Image Preprocessing Function**
|
40 |
def preprocess_image(image):
|
41 |
"""Applies necessary transformations to the input image."""
|
42 |
transform = transforms.Compose([
|
|
|
48 |
return transform(image).unsqueeze(0).to(device)
|
49 |
|
50 |
# β
**Classification Function**
|
51 |
+
def classify_image(selected_default, uploaded_image):
|
52 |
+
"""Processes either a default or uploaded image and returns the predicted clothing category."""
|
|
|
53 |
|
54 |
+
print("\n[INFO] Image selection process started.")
|
55 |
+
|
56 |
try:
|
57 |
+
# Use the uploaded image if provided; otherwise, use the selected default image
|
58 |
+
if uploaded_image is not None:
|
59 |
+
print("[INFO] Using uploaded image.")
|
60 |
+
image = Image.fromarray(uploaded_image) # Ensure conversion to PIL format
|
61 |
+
else:
|
62 |
+
print(f"[INFO] Using default image: {selected_default}")
|
63 |
+
image_path = default_images[selected_default]
|
64 |
+
image = Image.open(image_path) # Load the selected default image
|
65 |
+
|
66 |
image = preprocess_image(image) # Apply transformations
|
67 |
print("[INFO] Image transformed and moved to device.")
|
68 |
|
69 |
with torch.no_grad():
|
70 |
output = model(image)
|
71 |
+
|
72 |
# β
Ensure output is a tensor (handle tuple case)
|
73 |
if isinstance(output, tuple):
|
74 |
output = output[1] # Extract the actual output tensor
|
|
|
100 |
return "Error in classification. Check console for details."
|
101 |
|
102 |
# β
**Gradio Interface**
|
103 |
+
with gr.Blocks() as interface:
|
104 |
+
gr.Markdown("# Clothing1M Image Classifier")
|
105 |
+
gr.Markdown("Upload a clothing image or select from the predefined images below.")
|
106 |
+
|
107 |
+
# Default Image Selection
|
108 |
+
default_selector = gr.Radio(
|
109 |
+
choices=list(default_images.keys()),
|
110 |
+
label="Select a Default Image",
|
111 |
+
value="T-Shirt"
|
112 |
+
)
|
113 |
+
|
114 |
+
# File Upload Option
|
115 |
+
image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")
|
116 |
+
|
117 |
+
# Output Text
|
118 |
+
output_text = gr.Textbox(label="Classification Result")
|
119 |
+
|
120 |
+
# Classify Button
|
121 |
+
classify_button = gr.Button("Classify Image")
|
122 |
+
|
123 |
+
# Define Action
|
124 |
+
classify_button.click(
|
125 |
+
fn=classify_image,
|
126 |
+
inputs=[default_selector, image_upload],
|
127 |
+
outputs=output_text
|
128 |
+
)
|
129 |
|
130 |
# β
**Run the Interface**
|
131 |
if __name__ == "__main__":
|