File size: 3,858 Bytes
4a80f4f
 
 
 
 
c7f845c
4a80f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b95f530
c7f845c
385c8a3
b95f530
 
574ea78
c7f845c
 
57b4f24
 
491fff3
b95f530
4a80f4f
 
 
 
 
 
 
 
 
 
b95f530
c7f845c
 
4a80f4f
c7f845c
 
57b4f24
c7f845c
 
57b4f24
4a80f4f
57b4f24
 
4a80f4f
 
 
57b4f24
4a80f4f
 
 
 
 
 
 
 
 
 
57b4f24
4a80f4f
57b4f24
4a80f4f
b95f530
c7f845c
 
 
 
57b4f24
491fff3
57b4f24
 
 
491fff3
 
57b4f24
 
c7f845c
 
491fff3
c7f845c
 
57b4f24
c7f845c
 
57b4f24
c7f845c
 
57b4f24
c7f845c
 
57b4f24
c7f845c
 
 
 
 
4a80f4f
b95f530
4a80f4f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
from ResNet_for_CC import CC_model  # Import the model

# Set device (CPU/GPU) — prefer GPU when CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model
model_path = "CC_net.pt"  # checkpoint expected in the working directory — TODO confirm deployment layout
model = CC_model(num_classes1=14)  # 14 output classes; matches len(class_labels) below

# Load model weights
# NOTE(review): strict=False silently ignores missing/unexpected keys in the
# checkpoint — verify the state dict really matches this architecture, or
# mismatched layers will keep their random init without any warning.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()  # inference mode: freezes dropout / batch-norm statistics

# Clothing1M class labels — list index must line up with the model's logits.
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]

# Demo images offered in the UI, keyed by display label (label -> file path).
# NOTE(review): "Vest" points at dress.jpg — confirm this is the intended file.
default_images = {
    "Shawl": "shawlOG.webp",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    "Vest": "dress.jpg",
}

# The Gradio gallery widget wants (image_path, caption) pairs.
default_images_gallery = [(img, name) for name, img in default_images.items()]

# **Image Preprocessing Function**
def preprocess_image(image):
    """Resize, center-crop, tensorize and normalize *image*.

    Returns a 1xCxHxW float tensor on the module-level ``device``, ready
    to be fed to the model as a single-item batch.
    """
    # ImageNet normalization statistics, as used by pretrained ResNet backbones.
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    batch = pipeline(image).unsqueeze(0)  # add the batch dimension
    return batch.to(device)

# **Classification Function**
def classify_image(selected_default, uploaded_image):
    """Classify either an uploaded image or a selected default image.

    Args:
        selected_default: Key into ``default_images`` (gallery dropdown value).
        uploaded_image: Numpy array from the Gradio image widget, or None.

    Returns:
        A human-readable result string; errors are returned as text so the
        UI never crashes.
    """
    try:
        # An uploaded image takes priority over the default-image selection.
        if uploaded_image is not None:
            image = Image.fromarray(uploaded_image)
        else:
            image_path = default_images[selected_default]
            image = Image.open(image_path)

        # Fix: force 3-channel RGB. The default set includes .webp files,
        # which can decode as RGBA or palette mode, and uploads may be
        # grayscale — either would break the 3-channel Normalize transform.
        image = image.convert("RGB")

        image = preprocess_image(image)

        with torch.no_grad():
            output = model(image)
            # CC_model may return a tuple; the classification logits appear
            # to be the second element — TODO confirm against ResNet_for_CC.
            if isinstance(output, tuple):
                output = output[1]

            probabilities = F.softmax(output, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()

            if 0 <= predicted_class < len(class_labels):
                predicted_label = class_labels[predicted_class]
                confidence = probabilities[0][predicted_class].item() * 100
                return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
            else:
                return "[ERROR] Model returned an invalid class index."

    except Exception as e:
        # Best-effort: surface the failure in the UI text box instead of
        # letting the exception escape into Gradio.
        return f"Error in classification: {e}"

# **Gradio Interface**
# Builds the UI: a gallery of demo images, a dropdown to pick one, an
# optional upload widget, and a button that triggers classification.
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # Gallery to display default images (display only; selection happens
    # via the dropdown below, not by clicking the gallery).
    gallery = gr.Gallery(
        value=default_images_gallery,  # Provide list of (image, caption) tuples
        label="Default Images",
        elem_id="default_gallery"
    )

    # Default Image Selection — keys of default_images are the choices.
    default_selector = gr.Dropdown(
        choices=list(default_images.keys()),
        label="Select a Default Image",
        value="Shawl"
    )

    # File Upload Option — numpy output matches Image.fromarray in classify_image.
    image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")

    # Output Text
    output_text = gr.Textbox(label="Classification Result")

    # Classify Button
    classify_button = gr.Button("Classify Image")

    # Define Action — upload (if any) wins over the dropdown selection,
    # per the logic inside classify_image.
    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text
    )

# **Run the Interface**
# launch() blocks and serves the app (default: http://127.0.0.1:7860).
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()