import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
from ResNet_for_CC import CC_model  # Import the model

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model (14 Clothing1M categories)
model_path = "CC_net.pt"
model = CC_model(num_classes1=14)

# Load model weights; strict=False tolerates key mismatches between the
# checkpoint and the model definition instead of raising an error
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()

# Clothing1M Class Labels
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear"
]
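
# Sanity check: the label list must line up with the model's 14-way output
# head, otherwise the index lookup after argmax would be wrong
assert len(class_labels) == 14, "class_labels must match num_classes1=14"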

# **Predefined Default Images**
default_images = {
    "Shawl": "shawlOG.webp",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    "Vest": "dress.jpg"
}
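
# The sample images are assumed to sit next to this script; warn early if any
# file is missing so a blank gallery is easy to diagnose
for _name, _path in default_images.items():
    if not os.path.exists(_path):
        print(f"[WARN] Default image for '{_name}' not found: {_path}")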

# Gradio's Gallery accepts (image, caption) pairs, so pair each path with its label
default_image_paths = [(path, name) for name, path in default_images.items()]

# **Image Preprocessing Function**
def preprocess_image(image):
    """Resizes, center-crops, and normalizes the image to the model's expected 224x224 input."""
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet normalization statistics
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0).to(device)

# **Classification Function**
def classify_image(selected_default, uploaded_image):
    """Processes either a default or uploaded image and returns the predicted clothing category."""

    print("\n[INFO] Image selection process started.")
    
    try:
        # Use the uploaded image if provided; otherwise, fall back to the selected default
        if uploaded_image is not None:
            print("[INFO] Using uploaded image.")
            # Gradio delivers uploads as NumPy arrays; convert to PIL and force RGB
            image = Image.fromarray(uploaded_image).convert("RGB")
        else:
            print(f"[INFO] Using default image: {selected_default}")
            image_path = default_images[selected_default]
            # Force RGB: .webp/.png defaults may carry an alpha channel,
            # which would break the 3-channel normalization in preprocessing
            image = Image.open(image_path).convert("RGB")

        image = preprocess_image(image)  # Apply transformations
        print("[INFO] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(image)

            # CC_model is assumed to return a (features, logits) tuple;
            # keep the logits for classification
            if isinstance(output, tuple):
                output = output[1]

            print(f"[DEBUG] Model output shape: {output.shape}")
            print(f"[DEBUG] Model output values: {output}")

            if output.shape[1] != 14:
                return f"[ERROR] Model output mismatch! Expected 14 but got {output.shape[1]}."

            # Convert logits to probabilities
            probabilities = F.softmax(output, dim=1)
            print(f"[DEBUG] Softmax probabilities: {probabilities}")

            # Get predicted class index
            predicted_class = torch.argmax(probabilities, dim=1).item()
            print(f"[INFO] Predicted class index: {predicted_class} (Class: {class_labels[predicted_class]})")

            # Validate and return the prediction
            if 0 <= predicted_class < len(class_labels):
                predicted_label = class_labels[predicted_class]
                confidence = probabilities[0][predicted_class].item() * 100
                return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
            else:
                return "[ERROR] Model returned an invalid class index."

    except Exception as e:
        print(f"[ERROR] Exception during classification: {e}")
        return "Error in classification. Check console for details."

# **Gradio Interface**
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # **Default Image Gallery (preview only; the radio below drives the selection)**
    gr.Markdown("### Default Image Previews:")
    gallery = gr.Gallery(
        value=default_image_paths,
        label="Available Default Images",
        show_label=True
    )

    # **Default Image Selection Radio**
    default_selector = gr.Radio(
        choices=list(default_images.keys()),
        label="Select a Default Image",
        value="Shawl"
    )

    # **File Upload Option**
    image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")

    # **Output Text**
    output_text = gr.Textbox(label="Classification Result")

    # **Classify Button**
    classify_button = gr.Button("Classify Image")

    # **Define Action**
    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text
    )

# **Run the Interface**
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()
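    # launch() also accepts server_name="0.0.0.0" and share=True when the app
    # needs to be reachable remotely or via a public Gradio link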