|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
from ResNet_for_CC import CC_model |
|
|
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trained checkpoint. The network is built with a 14-way head, matching the
# 14 entries in class_labels.
model_path = "CC_net.pt"

model = CC_model(num_classes1=14)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(model_path, map_location=device)

# strict=False tolerates key mismatches, but silently ignoring them can leave
# layers at their random initialisation. Report anything that did not line up
# so a wrong or renamed checkpoint is visible in the console.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"[WARNING] Missing keys when loading {model_path}: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"[WARNING] Unexpected keys in {model_path}: {load_result.unexpected_keys}")

model.to(device)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
|
|
|
|
|
# Display names for the model's output indices: class_labels[i] is the label
# for logit i. The list matches the 14 Clothing1M categories named in the
# app title (presumably in the training label order; confirm if retrained).
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]

# Radio-button choices mapped to sample image files on disk. Each key is an
# entry of class_labels and matches the garment suggested by its file name.
# Previously "tshirt.jpg" was filed under "Shawl" and "dress.jpg" under "Vest".
default_images = {
    "T-Shirt": "tshirt.jpg",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    "Dress": "dress.jpg",
}
|
|
|
|
|
# Evaluation pipeline, built once at import time instead of on every request.
# Resize the short side to 256, then center-crop to 224x224 and normalise per
# channel. The mean/std are presumably the statistics the network was trained
# with (the common ImageNet values). Confirm if the model is retrained.
_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalised single-image batch on ``device``.

    Args:
        image: A ``PIL.Image.Image`` in any mode. It is converted to RGB first,
            so grayscale, palette, RGBA or CMYK inputs still produce the three
            channels that the 3-element ``Normalize`` expects. Without this
            conversion those inputs raised a shape error.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` placed on ``device``.
    """
    return _TRANSFORM(image.convert("RGB")).unsqueeze(0).to(device)
|
|
|
|
|
def _load_image(selected_default, uploaded_image):
    """Return an RGB PIL image from the upload, or from the chosen default file."""
    if uploaded_image is not None:
        print("[INFO] Using uploaded image.")
        return Image.fromarray(uploaded_image).convert("RGB")

    print(f"[INFO] Using default image: {selected_default}")
    image_path = default_images[selected_default]
    # Use a context manager so the file handle is closed. convert() forces the
    # lazily-loaded pixel data to be read before the file is released.
    with Image.open(image_path) as img:
        return img.convert("RGB")


def classify_image(selected_default, uploaded_image):
    """Classify an uploaded or default image and describe the prediction.

    An uploaded image takes precedence over the radio selection.

    Args:
        selected_default: Key of ``default_images`` chosen in the UI (may be None).
        uploaded_image: Image array from ``gr.Image(type="numpy")``, or None when
            nothing was uploaded.

    Returns:
        A human-readable string. On success it holds the predicted label and
        its softmax confidence. Otherwise it holds an error message; details of
        unexpected failures are printed to the console.
    """
    import traceback

    print("\n[INFO] Image selection process started.")

    # Nothing to classify: no upload and no (valid) default selected. Report it
    # clearly instead of letting a KeyError fall into the generic handler.
    if uploaded_image is None and selected_default not in default_images:
        return "[ERROR] No image provided. Upload an image or select a default one."

    try:
        image = preprocess_image(_load_image(selected_default, uploaded_image))
        print("[INFO] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(image)

        # CC_model may return several outputs. As before, element 1 is taken as
        # the clothing-class logits (TODO confirm against ResNet_for_CC).
        if isinstance(output, tuple):
            output = output[1]

        print(f"[DEBUG] Model output shape: {output.shape}")
        print(f"[DEBUG] Model output values: {output}")

        num_classes = len(class_labels)
        if output.shape[1] != num_classes:
            return f"[ERROR] Model output mismatch! Expected {num_classes} but got {output.shape[1]}."

        probabilities = F.softmax(output, dim=1)
        print(f"[DEBUG] Softmax probabilities: {probabilities}")

        # The shape check above guarantees the argmax indexes class_labels.
        predicted_class = torch.argmax(probabilities, dim=1).item()
        predicted_label = class_labels[predicted_class]
        print(f"[INFO] Predicted class index: {predicted_class} (Class: {predicted_label})")

        confidence = probabilities[0, predicted_class].item() * 100
        return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"

    except Exception as e:
        # UI boundary: keep the app alive, but leave a full traceback in the
        # console so "check console" actually has the details.
        print(f"[ERROR] Exception during classification: {e}")
        traceback.print_exc()
        return "Error in classification. Check console for details."
|
|
|
|
|
# Initial radio selection: prefer "T-Shirt" when it is an offered default.
# Otherwise fall back to the first available choice (or none), because the
# Radio's value must be one of its choices to be accepted on submit.
_initial_default = "T-Shirt" if "T-Shirt" in default_images else next(iter(default_images), None)

# UI layout: a default-image picker, an optional upload, a result box and a
# button that runs classify_image on the current picker/upload values.
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    default_selector = gr.Radio(
        choices=list(default_images),
        label="Select a Default Image",
        value=_initial_default,
    )

    # type="numpy" hands classify_image an array (or None when empty).
    image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")

    output_text = gr.Textbox(label="Classification Result")

    classify_button = gr.Button("Classify Image")

    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text,
    )
|
|
|
|
|
def _launch():
    """Announce start-up on the console and start serving the Gradio app."""
    print("[INFO] Launching Gradio interface...")
    interface.launch()


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    _launch()
|
|