File size: 3,858 Bytes
4a80f4f c7f845c 4a80f4f b95f530 c7f845c 385c8a3 b95f530 574ea78 c7f845c 57b4f24 491fff3 b95f530 4a80f4f b95f530 c7f845c 4a80f4f c7f845c 57b4f24 c7f845c 57b4f24 4a80f4f 57b4f24 4a80f4f 57b4f24 4a80f4f 57b4f24 4a80f4f 57b4f24 4a80f4f b95f530 c7f845c 57b4f24 491fff3 57b4f24 491fff3 57b4f24 c7f845c 491fff3 c7f845c 57b4f24 c7f845c 57b4f24 c7f845c 57b4f24 c7f845c 57b4f24 c7f845c 4a80f4f b95f530 4a80f4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
from ResNet_for_CC import CC_model # Import the model
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and load the trained CC_model weights.
# num_classes1 must agree with the number of entries in class_labels, because
# the predicted index is used to look up a label there.
model_path = "CC_net.pt"
model = CC_model(num_classes1=14)

# NOTE(review): torch.load unpickles the checkpoint, so only ever point
# model_path at a trusted file. map_location remaps tensors saved on another
# device (e.g. a GPU) onto the device selected above.
state_dict = torch.load(model_path, map_location=device)

# strict=False tolerates key mismatches, but it does so silently: any
# parameter whose key is missing keeps its initial value. Surface mismatches
# so a wrong/renamed checkpoint does not quietly produce a random model.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "[WARNING] Checkpoint/model key mismatch: "
        f"{len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected. "
        f"First missing: {load_result.missing_keys[:5]} "
        f"First unexpected: {load_result.unexpected_keys[:5]}"
    )

model.to(device)
model.eval()  # inference mode for layers such as dropout / batch-norm
# Clothing1M category names. The model's output index i is looked up as
# class_labels[i], so the order here must match the training label order.
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]

# Bundled sample images offered in the UI: caption -> image file path.
# NOTE(review): "Vest" points at "dress.jpg" -- confirm that file actually
# shows a vest, or rename the key/file so the caption matches the image.
default_images = {
    "Shawl": "shawlOG.webp",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    "Vest": "dress.jpg",
}

# gr.Gallery takes (image, caption) pairs, i.e. each (caption, path) item swapped.
default_images_gallery = list(zip(default_images.values(), default_images.keys()))
# **Image Preprocessing Function**
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched tensor on ``device``.

    The image is first forced to 3-channel RGB. The Normalize step uses
    3-element mean/std, so grayscale ("L"), palette ("P") or RGBA inputs
    (e.g. transparent .webp/.png files) would otherwise fail with a
    channel/shape mismatch.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        A float tensor of shape (1, 3, 224, 224) on ``device``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(256),      # scale the shorter side to 256 px
        transforms.CenterCrop(224),  # take the central 224x224 patch
        transforms.ToTensor(),       # PIL -> CHW float tensor in [0, 1]
        # Standard ImageNet channel statistics (presumably what the model
        # was trained with -- confirm against the training pipeline).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the batch dimension the model expects.
    return transform(image).unsqueeze(0).to(device)
# **Classification Function**
def classify_image(selected_default, uploaded_image):
    """Classify the uploaded image, or the selected default image if none was uploaded.

    Args:
        selected_default: Key into ``default_images`` chosen in the dropdown.
            May be None if the user cleared the dropdown.
        uploaded_image: Image array from the upload widget, or None when
            nothing has been uploaded.

    Returns:
        A human-readable result string. Failures are returned as a message
        instead of raised, so they appear in the output textbox.
    """
    try:
        # An uploaded image takes precedence over the dropdown selection.
        if uploaded_image is not None:
            pil_image = Image.fromarray(uploaded_image).convert("RGB")
        else:
            if selected_default not in default_images:
                return f"Error in classification: unknown default image {selected_default!r}"
            # The context manager releases the file handle. convert() forces
            # the pixel data to be read before the file is closed.
            with Image.open(default_images[selected_default]) as img:
                pil_image = img.convert("RGB")

        batch = preprocess_image(pil_image)
        with torch.no_grad():
            output = model(batch)
        # NOTE(review): the model may return a tuple. The second element is
        # treated as the class logits -- confirm against ResNet_for_CC.
        if isinstance(output, tuple):
            output = output[1]

        probabilities = F.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        # Guard against a model head wider than the label list.
        if not 0 <= predicted_class < len(class_labels):
            return "[ERROR] Model returned an invalid class index."

        predicted_label = class_labels[predicted_class]
        confidence = probabilities[0][predicted_class].item() * 100
        return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
    except Exception as e:
        # UI boundary: log for the operator, report in the textbox for the user.
        print(f"[ERROR] Classification failed: {e!r}")
        return f"Error in classification: {e}"
# **Gradio Interface**
with gr.Blocks() as interface:
    # Page title and a one-line usage hint.
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # Preview of the bundled samples as (image path, caption) pairs.
    # It is not connected to any event handler -- display only.
    gallery = gr.Gallery(
        label="Default Images",
        value=default_images_gallery,
        elem_id="default_gallery",
    )

    # Sample to classify when no image has been uploaded.
    default_selector = gr.Dropdown(
        label="Select a Default Image",
        choices=[*default_images],
        value="Shawl",
    )

    # Optional user upload, handed to the callback as a numpy array.
    image_upload = gr.Image(label="Or Upload Your Own Image", type="numpy")

    output_text = gr.Textbox(label="Classification Result")
    classify_button = gr.Button("Classify Image")

    # (dropdown value, uploaded array) -> classify_image -> result textbox.
    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text,
    )
# **Run the Interface**
def main():
    """Announce startup on stdout and launch the Gradio app defined above."""
    print("[INFO] Launching Gradio interface...")
    interface.launch()


if __name__ == "__main__":
    main()
|