import subprocess
import sys

# Import torch, installing it on the fly when the environment lacks it.
# The guarded import has to be the first import of torch: an unconditional
# `import torch` placed ahead of it would raise before the fallback could run.
try:
    import torch
except ModuleNotFoundError:
    print("🚨 Torch not found! Installing...")
    # Use the running interpreter's pip so the packages land in this environment.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio"],
        check=True,
    )
    import torch

print("✅ PyTorch version:", torch.__version__)
# Human-readable label for each model output index; predict_clothing indexes
# this list with the arg-max of the model output, so the order matters.
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]
def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'):
    """Build a ``SupCEResNet`` and load pretrained weights from a checkpoint.

    Per the original summary, this loads a self-supervised pretrained model
    for Clothing1M classification.

    Args:
        net: Backbone identifier, forwarded as the first positional argument
            of ``SupCEResNet``.
        num_class: Number of output classes, forwarded as ``num_classes``.
        checkpoint_path: File handed to ``torch.load``. The loaded object must
            expose the weights under the ``'model'`` key.

    Returns:
        The model, moved to CUDA when available (CPU otherwise) and switched
        to eval mode.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
    """
    # Resolve the target device once; it is used for loading and placement.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from: {checkpoint_path}")

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # so only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Remove every 'module.' segment from the parameter names (presumably left
    # behind by a DataParallel wrapper -- confirm against the training code).
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}

    model = SupCEResNet(net, num_classes=num_class, pool=True)

    # strict=False tolerates name mismatches; report them so a checkpoint that
    # does not match the architecture is not accepted silently.
    load_result = model.load_state_dict(state_dict, strict=False)
    for kind in ("missing_keys", "unexpected_keys"):
        names = getattr(load_result, kind, None)
        if names:
            print(f"Warning: {len(names)} {kind.replace('_', ' ')} while loading checkpoint: {list(names)}")

    model = model.to(device)
    model.eval()

    print("✅ Model loaded successfully!")
    return model
# Load the classifier once at import time; predict_clothing reads this global.
model = create_model_selfsup()
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched tensor for the model.

    The pipeline applies ``Resize(256)``, ``CenterCrop(224)``, ``ToTensor()``
    and a three-channel ``Normalize``, then adds a leading batch dimension.
    ``transforms`` is expected to be in scope (assumed to be
    ``torchvision.transforms`` -- it is not imported in this section).

    Args:
        image: A PIL image (``predict_clothing`` passes the result of
            ``Image.fromarray``).

    Returns:
        The transformed tensor with a batch axis prepended, placed on CUDA
        when available and on the CPU otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Normalize is configured with three means/stds, so force three channels:
    # grayscale or RGBA inputs would otherwise fail at that step.
    rgb_image = image.convert("RGB")

    return transform(rgb_image).unsqueeze(0).to(device)
def predict_clothing(image):
    """Classify an uploaded image and return its clothing label.

    Args:
        image: Array accepted by ``Image.fromarray`` (the Gradio input in
            this file is configured with ``type="numpy"``), or ``None`` when
            nothing was uploaded.

    Returns:
        The ``class_labels`` entry at the arg-max index of the model output,
        or a short prompt string when ``image`` is ``None``.
    """
    # Submitting the form without an image yields None; answer with a prompt
    # instead of letting Image.fromarray raise.
    if image is None:
        return "Please upload an image first."

    batch = preprocess_image(Image.fromarray(image))

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(batch)

    predicted_class = torch.argmax(output, dim=1).item()
    return class_labels[predicted_class]
# Build the Gradio UI around predict_clothing and start serving it.
# The image is delivered to the callback as a NumPy array (type="numpy").
gr.Interface(
    fn=predict_clothing,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Textbox(label="Predicted Clothing Type"),
    title="Clothing1M Classification",
    description="Upload an image to classify clothing into one of 14 categories.",
).launch()