File size: 4,610 Bytes
4a80f4f c7f845c 4a80f4f b95f530 c7f845c 574ea78 b95f530 574ea78 c7f845c b95f530 4a80f4f b95f530 c7f845c 4a80f4f c7f845c 4a80f4f c7f845c 4a80f4f c7f845c b95f530 4a80f4f b95f530 c7f845c 4a80f4f b95f530 4a80f4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
from ResNet_for_CC import CC_model # Import the model
# Set device (CPU/GPU): use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model.
model_path = "CC_net.pt"
model = CC_model(num_classes1=14)

# Load model weights, remapping stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
state_dict = torch.load(model_path, map_location=device)

# strict=False tolerates key mismatches, but ignoring them silently can leave
# layers with their random initial weights. Report any mismatch at startup
# instead of failing, so a wrong/renamed checkpoint is visible.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"[WARNING] Missing keys when loading {model_path}: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"[WARNING] Unexpected keys in {model_path}: {load_result.unexpected_keys}")

model.to(device)
model.eval()  # inference mode for layers such as dropout / batch norm
# Clothing1M class labels. The label at position i names the model's output
# column i (the classifier indexes this list with the argmax of the logits).
class_labels = [
    "T-Shirt",      # 0
    "Shirt",        # 1
    "Knitwear",     # 2
    "Chiffon",      # 3
    "Sweater",      # 4
    "Hoodie",       # 5
    "Windbreaker",  # 6
    "Jacket",       # 7
    "Downcoat",     # 8
    "Suit",         # 9
    "Shawl",        # 10
    "Dress",        # 11
    "Vest",         # 12
    "Underwear",    # 13
]
# **Predefined Default Images**
# Maps the label shown in the UI selector to an image file opened relative to
# the working directory. Each key names the garment in its file.
# NOTE(review): these keys were "Shawl" -> tshirt.jpg and "Vest" -> dress.jpg;
# they were renamed to match the file names (and the selector's "T-Shirt"
# default) — confirm against the actual images.
default_images = {
    "T-Shirt": "tshirt.jpg",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    "Dress": "dress.jpg",
}
# **Image Preprocessing Function**
def preprocess_image(image):
    """Turn a PIL image into a normalised, batched model input on ``device``.

    The image is converted to 3-channel RGB first: ``Normalize`` below uses
    3-element mean/std, so RGBA (some PNG/WebP files), palette or grayscale
    images would otherwise produce a tensor with the wrong channel count.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` on ``device``.
    """
    transform = transforms.Compose([
        transforms.Resize(256),      # shorter side -> 256, aspect ratio kept
        transforms.CenterCrop(224),
        transforms.ToTensor(),       # HWC [0, 255] -> CHW float [0, 1]
        # Commonly used ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    rgb_image = image.convert("RGB")
    return transform(rgb_image).unsqueeze(0).to(device)  # add batch dimension
# **Classification Function**
def _load_input_image(selected_default, uploaded_image):
    """Return the image to classify as a 3-channel RGB PIL image.

    The uploaded image wins when present; otherwise the file registered under
    ``selected_default`` in ``default_images`` is opened. The caller must have
    checked that ``selected_default`` is a key of ``default_images`` in that case.
    """
    if uploaded_image is not None:
        print("[INFO] Using uploaded image.")
        # The upload widget is expected to deliver a numpy array.
        return Image.fromarray(uploaded_image).convert("RGB")

    print(f"[INFO] Using default image: {selected_default}")
    # Context manager releases the file handle; convert() forces the lazy load
    # and returns an independent in-memory copy.
    with Image.open(default_images[selected_default]) as img:
        return img.convert("RGB")


def classify_image(selected_default, uploaded_image):
    """Classify an uploaded image, or a predefined default one, and describe the result.

    Args:
        selected_default: Key of ``default_images`` chosen in the UI; only used
            when ``uploaded_image`` is None.
        uploaded_image: The uploaded image as a numpy array, or None.

    Returns:
        A human-readable string with the predicted class and its softmax
        confidence, or an error message. Errors are returned as text (not
        raised) so the result textbox always shows something.
    """
    print("\n[INFO] Image selection process started.")
    num_classes = len(class_labels)
    try:
        if uploaded_image is None and selected_default not in default_images:
            return "[ERROR] Please upload an image or select one of the default images."

        image = _load_input_image(selected_default, uploaded_image)
        batch = preprocess_image(image)
        print("[INFO] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(batch)

        # NOTE(review): when the model returns a tuple, element 1 is taken as
        # the class logits — confirm against CC_model's forward().
        if isinstance(output, tuple):
            output = output[1]

        print(f"[DEBUG] Model output shape: {output.shape}")
        print(f"[DEBUG] Model output values: {output}")

        if output.shape[1] != num_classes:
            return f"[ERROR] Model output mismatch! Expected {num_classes} but got {output.shape[1]}."

        # Convert logits to probabilities over the classes (dim 1).
        probabilities = F.softmax(output, dim=1)
        print(f"[DEBUG] Softmax probabilities: {probabilities}")

        predicted_class = torch.argmax(probabilities, dim=1).item()
        # Validate before indexing so a bad index cannot raise IndexError.
        if not 0 <= predicted_class < num_classes:
            return "[ERROR] Model returned an invalid class index."

        predicted_label = class_labels[predicted_class]
        print(f"[INFO] Predicted class index: {predicted_class} (Class: {predicted_label})")
        confidence = probabilities[0][predicted_class].item() * 100
        return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
    except Exception as err:
        # UI boundary: report the failure instead of crashing the request, but
        # keep the full traceback on the console for debugging.
        import traceback

        print(f"[ERROR] Exception during classification: {err}")
        traceback.print_exc()
        return "Error in classification. Check console for details."
# **Gradio Interface**
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # Default Image Selection. The initial value must be one of `choices`;
    # otherwise the submitted selection is not a known default image. Use the
    # first predefined image (None when there are none).
    default_choices = list(default_images.keys())
    default_selector = gr.Radio(
        choices=default_choices,
        label="Select a Default Image",
        value=default_choices[0] if default_choices else None,
    )

    # File Upload Option: passed to classify_image as a numpy array.
    image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")

    # Output Text
    output_text = gr.Textbox(label="Classification Result")

    # Classify Button
    classify_button = gr.Button("Classify Image")

    # Define Action: classify_image prefers the upload over the default image.
    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text,
    )
# **Run the Interface**
def _launch_app():
    """Announce startup on stdout, then start the Gradio server for `interface`."""
    print("[INFO] Launching Gradio interface...")
    interface.launch()


if __name__ == "__main__":
    _launch_app()
|