import logging
import os

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from ResNet_for_CC import CC_model  # Import the model

logger = logging.getLogger(__name__)

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model
model_path = "CC_net.pt"
model = CC_model(num_classes1=14)

# Load model weights.
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files.
# If the checkpoint is a plain state_dict and PyTorch >= 1.13 is guaranteed,
# consider passing weights_only=True.
state_dict = torch.load(model_path, map_location=device)
# strict=False silently skips keys that do not match the model. Report them so
# that a checkpoint/architecture mismatch (which would leave the affected layers
# at their initial weights) is visible instead of yielding meaningless output.
load_result = model.load_state_dict(state_dict, strict=False)
missing_keys = getattr(load_result, "missing_keys", ())
unexpected_keys = getattr(load_result, "unexpected_keys", ())
if missing_keys:
    logger.warning(
        "Checkpoint %s is missing %d model keys: %s",
        model_path, len(missing_keys), missing_keys,
    )
if unexpected_keys:
    logger.warning(
        "Checkpoint %s has %d keys not used by the model: %s",
        model_path, len(unexpected_keys), unexpected_keys,
    )
model.to(device)
model.eval()

# Clothing1M Class Labels (index i corresponds to model output column i)
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]

# **Predefined Default Images** (dropdown label -> image file path, relative to
# the current working directory)
default_images = {
    "Shawl": "shawlOG.webp",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    # NOTE(review): the "Vest" entry points at dress.jpg; confirm this is intended.
    "Vest": "dress.jpg",
}

# Convert to gallery format (list of (image_path, caption) tuples)
default_images_gallery = [(path, label) for label, path in default_images.items()]

# Preprocessing pipeline, built once at import time rather than on every request.
# Resize/center-crop to 224x224 and normalize with the ImageNet mean/std.
_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# **Image Preprocessing Function**
def preprocess_image(image):
    """Turn a PIL image into a normalized model-input batch.

    The image is converted to RGB first, so RGBA, palette and grayscale
    inputs (common for .webp/.png files) yield the 3 channels that the
    3-element Normalize step expects.

    Args:
        image: A PIL image of any mode.

    Returns:
        A float tensor of shape (1, 3, 224, 224) on ``device``.
    """
    return _TRANSFORM(image.convert("RGB")).unsqueeze(0).to(device)


# **Classification Function**
def classify_image(selected_default, uploaded_image):
    """Classify the uploaded image, or the selected default image if none was uploaded.

    Args:
        selected_default: Key into ``default_images`` chosen in the dropdown
            (may be None if the dropdown was cleared).
        uploaded_image: Image array from the ``gr.Image(type="numpy")``
            component, or None when nothing was uploaded.

    Returns:
        A human-readable string with the predicted label and its softmax
        confidence, or an error message. Exceptions are not raised to the UI.
    """
    try:
        # Use the uploaded image if provided; otherwise use the selected default.
        if uploaded_image is not None:
            image = Image.fromarray(uploaded_image)
        elif selected_default in default_images:
            # Context manager closes the file. convert() forces the pixel data
            # to load before the file is closed (Image.open is lazy).
            with Image.open(default_images[selected_default]) as img:
                image = img.convert("RGB")
        else:
            return "[ERROR] Please upload an image or select a default image."

        batch = preprocess_image(image)

        with torch.no_grad():
            output = model(batch)
            # NOTE(review): the model may return a tuple; element 1 is assumed
            # to hold the classification logits. Confirm against CC_model.
            if isinstance(output, tuple):
                output = output[1]
            probabilities = F.softmax(output, dim=1)

        predicted_class = torch.argmax(probabilities, dim=1).item()
        # Guard against a model head whose size differs from class_labels.
        if not 0 <= predicted_class < len(class_labels):
            return "[ERROR] Model returned an invalid class index."

        predicted_label = class_labels[predicted_class]
        confidence = probabilities[0][predicted_class].item() * 100
        return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
    except Exception as e:  # UI boundary: report to the user, but keep the traceback
        logger.exception("Classification failed")
        return f"Error in classification: {e}"


# **Gradio Interface**
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # Gallery to display default images
    gallery = gr.Gallery(
        value=default_images_gallery,  # Provide list of (image, caption) tuples
        label="Default Images",
        elem_id="default_gallery",
    )

    # Default Image Selection
    default_selector = gr.Dropdown(
        choices=list(default_images.keys()),
        label="Select a Default Image",
        value="Shawl",
    )

    # File Upload Option
    image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")

    # Output Text
    output_text = gr.Textbox(label="Classification Result")

    # Classify Button
    classify_button = gr.Button("Classify Image")

    # Define Action
    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text,
    )

# **Run the Interface**
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()