File size: 4,610 Bytes
4a80f4f
 
 
 
 
 
c7f845c
 
4a80f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b95f530
c7f845c
574ea78
b95f530
 
574ea78
c7f845c
 
b95f530
4a80f4f
 
 
 
 
 
 
 
 
 
b95f530
c7f845c
 
4a80f4f
c7f845c
 
4a80f4f
c7f845c
 
 
 
 
 
 
 
 
4a80f4f
 
 
 
 
c7f845c
b95f530
4a80f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b95f530
c7f845c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a80f4f
b95f530
4a80f4f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
from ResNet_for_CC import CC_model  # Import the model

# Set device (CPU/GPU): use CUDA when it is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model
model_path = "CC_net.pt"
model = CC_model(num_classes1=14)

# Load model weights; map_location remaps the stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
state_dict = torch.load(model_path, map_location=device)

# strict=False accepts a checkpoint whose keys do not fully match the model.
# Report any mismatch instead of ignoring it, because parameters that were not
# loaded keep their initial values and would silently degrade predictions.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys:
    print(f"[WARN] Keys missing from checkpoint: {_load_result.missing_keys}")
if _load_result.unexpected_keys:
    print(f"[WARN] Unexpected keys in checkpoint: {_load_result.unexpected_keys}")

model.to(device)
model.eval()  # inference mode for the module (no training-time behaviour)

# Clothing1M Class Labels
# classify_image looks up the predicted class index in this list, so the
# order of the entries is significant.
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]

# **Predefined Default Images**
# Maps the name shown in the UI selector to the image file that is loaded.
# NOTE(review): "Shawl" -> tshirt.jpg and "Vest" -> dress.jpg do not match the
# file names - confirm these display names are intentional.
default_images = {
    "Shawl": "tshirt.jpg",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    "Vest": "dress.jpg",
}

# **Image Preprocessing Function**
# The pipeline holds no per-call state, so it is built once at import time
# instead of being reconstructed for every image.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_image(image):
    """Convert a PIL image into a normalized single-image batch on `device`.

    Args:
        image: PIL image in any mode; it is converted to RGB when needed.

    Returns:
        Tensor with a leading batch dimension of size 1, moved to `device`.
    """
    # Normalize is configured with three per-channel statistics, so RGBA,
    # palette or greyscale inputs must be converted to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0).to(device)

# **Classification Function**
def classify_image(selected_default, uploaded_image):
    """Classify an uploaded image, or a predefined one, into a clothing category.

    Args:
        selected_default: Key into `default_images`; only used when no image
            was uploaded.
        uploaded_image: Uploaded image as a NumPy array, or None when the
            user did not upload anything.

    Returns:
        A display string: the predicted label with its confidence, or an
        error message when the image cannot be loaded or classified.
    """
    print("\n[INFO] Image selection process started.")

    try:
        # Use the uploaded image if provided; otherwise, use the selected default image
        if uploaded_image is not None:
            print("[INFO] Using uploaded image.")
            image = Image.fromarray(uploaded_image).convert("RGB")
        else:
            print(f"[INFO] Using default image: {selected_default}")
            image_path = default_images.get(selected_default)
            if image_path is None:
                # Nothing uploaded and the selection is missing or unknown:
                # say so explicitly rather than failing with a KeyError.
                return "[ERROR] No image uploaded and no valid default image selected."
            # convert() reads the pixel data and returns an independent image,
            # so the underlying file can be closed right away.
            with Image.open(image_path) as image_file:
                image = image_file.convert("RGB")

        batch = preprocess_image(image)  # Apply transformations
        print("[INFO] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(batch)

        # Ensure output is a tensor (handle tuple case)
        # NOTE(review): assumes the logits are the second tuple element - confirm against CC_model.
        if isinstance(output, tuple):
            output = output[1]

        print(f"[DEBUG] Model output shape: {output.shape}")
        print(f"[DEBUG] Model output values: {output}")

        # One logit per label is required; once this holds, the argmax below
        # is always a valid index into class_labels.
        num_classes = len(class_labels)
        if output.shape[1] != num_classes:
            return f"[ERROR] Model output mismatch! Expected {num_classes} but got {output.shape[1]}."

        # Convert logits to probabilities
        probabilities = F.softmax(output, dim=1)
        print(f"[DEBUG] Softmax probabilities: {probabilities}")

        # Get predicted class index
        predicted_class = torch.argmax(probabilities, dim=1).item()
        predicted_label = class_labels[predicted_class]
        print(f"[INFO] Predicted class index: {predicted_class} (Class: {predicted_label})")

        confidence = probabilities[0][predicted_class].item() * 100
        return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"

    except Exception as e:
        # UI boundary: keep the app responsive, but leave the full traceback
        # on the console that the returned message points the user to.
        import traceback

        print(f"[ERROR] Exception during classification: {e}")
        traceback.print_exc()
        return "Error in classification. Check console for details."

# **Gradio Interface**
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # Default Image Selection
    default_choices = list(default_images)
    default_selector = gr.Radio(
        choices=default_choices,
        label="Select a Default Image",
        # The preselected value must be one of the choices; derive it from the
        # choices so it cannot go stale when default_images is edited.
        value=default_choices[0] if default_choices else None,
    )

    # File Upload Option (delivered to classify_image as a NumPy array)
    image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")

    # Output Text
    output_text = gr.Textbox(label="Classification Result")

    # Classify Button
    classify_button = gr.Button("Classify Image")

    # Define Action: run classify_image on click and show its string result
    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text,
    )

# **Run the Interface**
# Launch only when executed as a script, not when imported as a module
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()