File size: 2,796 Bytes
9197021 49d337b 9197021 048aeab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import importlib
import subprocess
import sys

# ✅ Ensure torch is installed before importing.
# The import must live inside this guard: an unconditional `import torch`
# earlier in the file would raise before the fallback install could run.
try:
    import torch
except ModuleNotFoundError:
    print("🚨 Torch not found! Installing...")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio"],
        check=True,
    )
    # Import finders cache directory contents; refresh them so the package
    # installed moments ago is visible to the re-import below.
    importlib.invalidate_caches()
    import torch  # Try again after installation

print("✅ PyTorch version:", torch.__version__)

# ✅ Define class labels (from Clothing1M).
# The list order defines the index -> name mapping used to decode argmax
# predictions, so it must match the order the checkpoint was trained with.
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]
# ✅ Function to load the model
def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'):
    """Load a self-supervised pretrained model for Clothing1M classification.

    Args:
        net: Backbone name forwarded to ``SupCEResNet`` (e.g. ``'resnet50'``).
        num_class: Number of classes, forwarded as ``num_classes``.
        checkpoint_path: Path to a checkpoint whose ``'model'`` entry holds the
            state dict.

    Returns:
        The model with the checkpoint weights loaded, moved to CUDA when
        available (otherwise CPU), and switched to eval mode.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"π Loading model from: {checkpoint_path}")

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from a trusted
    # source; consider weights_only=True if the file holds only tensors and
    # plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Strip the leading 'module.' prefix added by DataParallel. Only a leading
    # prefix is removed: str.replace would also mangle keys that merely contain
    # 'module.' further inside the name.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['model'].items()
    }

    # NOTE(review): SupCEResNet is neither defined nor imported in this file;
    # it must be provided by the surrounding project — confirm.
    model = SupCEResNet(net, num_classes=num_class, pool=True)

    # strict=False tolerates key mismatches, but silently ignoring them could
    # leave layers at their random initialization, so report what was skipped.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"⚠️ Keys missing from checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"⚠️ Unexpected keys in checkpoint: {load_result.unexpected_keys}")

    model = model.to(device)
    model.eval()  # Disable dropout / use running batch-norm statistics
    print("✅ Model loaded successfully!")
    return model
# ✅ Load the model once at import time so every request reuses the same instance.
model = create_model_selfsup()
# ✅ Define image preprocessing function
def preprocess_image(image):
    """Transform a PIL image into a normalized single-image model batch.

    The shorter side is resized to 256, the image is center-cropped to
    224x224, converted to a tensor, and normalized with the standard ImageNet
    mean/std values.

    Args:
        image: A PIL image. Non-RGB modes (grayscale, RGBA, palette, ...) are
            converted to RGB first.

    Returns:
        A float tensor of shape (1, 3, 224, 224), on CUDA when available,
        otherwise on CPU.
    """
    # NOTE(review): `transforms` (torchvision.transforms) is not imported in
    # this file; confirm it is brought into scope elsewhere.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Normalize above uses 3-channel statistics; grayscale ("L") or RGBA input
    # would produce 1 or 4 channels and fail, so force RGB.
    rgb_image = image.convert("RGB")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return transform(rgb_image).unsqueeze(0).to(device)
# ✅ Define inference function
def predict_clothing(image):
    """Classify an uploaded image and return the predicted clothing label.

    Args:
        image: The image as a numpy array (the input component is created
            with ``type="numpy"``), or ``None`` when nothing was uploaded.

    Returns:
        The predicted class name from ``class_labels``, or a prompt asking the
        user to upload an image when ``image`` is ``None``.
    """
    # An empty input arrives as None; Image.fromarray(None) would raise, so
    # answer with a message instead of an error.
    if image is None:
        return "Please upload an image."

    pil_image = Image.fromarray(image)  # numpy array -> PIL Image
    batch = preprocess_image(pil_image)
    with torch.no_grad():
        logits = model(batch)
    predicted_index = torch.argmax(logits, dim=1).item()  # index of top class
    return class_labels[predicted_index]
# ✅ Create Gradio Interface and launch it.
demo = gr.Interface(
    fn=predict_clothing,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Textbox(label="Predicted Clothing Type"),
    title="Clothing1M Classification",
    # Derive the count from class_labels so the text cannot drift from the labels.
    description=f"Upload an image to classify clothing into one of {len(class_labels)} categories.",
)
demo.launch()
|