Moditha24 commited on
Commit
048aeab
Β·
verified Β·
1 Parent(s): a673bf7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from resnet import SupCEResNet # Ensure the correct import path
6
+
7
+ # βœ… Define class labels (from Clothing1M)
8
+ class_labels = [
9
+ "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
10
+ "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
11
+ "Vest", "Underwear"
12
+ ]
13
+
14
+ # βœ… Function to load the model
15
+ def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'):
16
+ """Loads a self-supervised pretrained model for Clothing1M classification"""
17
+ print(f"πŸ”„ Loading model from: {checkpoint_path}")
18
+
19
+ # Load the checkpoint safely
20
+ checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu", weights_only=False)
21
+
22
+ # Remove 'module.' prefix if using DataParallel
23
+ state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}
24
+
25
+ # Initialize and load model
26
+ model = SupCEResNet(net, num_classes=num_class, pool=True)
27
+ model.load_state_dict(state_dict, strict=False)
28
+
29
+ # Move model to GPU if available
30
+ model = model.to("cuda" if torch.cuda.is_available() else "cpu")
31
+ model.eval() # Set model to evaluation mode
32
+
33
+ print("βœ… Model loaded successfully!")
34
+ return model
35
+
36
+ # βœ… Load the model once
37
+ model = create_model_selfsup()
38
+
39
+ # βœ… Define image preprocessing function
40
+ def preprocess_image(image):
41
+ """Transforms input image for the model"""
42
+ transform = transforms.Compose([
43
+ transforms.Resize(256),
44
+ transforms.CenterCrop(224),
45
+ transforms.ToTensor(),
46
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
47
+ ])
48
+ return transform(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
49
+
50
+ # βœ… Define inference function
51
+ def predict_clothing(image):
52
+ """Runs inference on an uploaded image"""
53
+ image = Image.fromarray(image) # Convert numpy array to PIL Image
54
+ image = preprocess_image(image) # Preprocess image
55
+
56
+ with torch.no_grad():
57
+ output = model(image)
58
+ predicted_class = torch.argmax(output, dim=1).item() # Get class index
59
+
60
+ return class_labels[predicted_class] # Return class name
61
+
62
+ # βœ… Create Gradio Interface
63
+ gr.Interface(
64
+ fn=predict_clothing,
65
+ inputs=gr.Image(type="numpy"),
66
+ outputs=gr.Textbox(label="Predicted Clothing Type"),
67
+ title="Clothing1M Classification",
68
+ description="Upload an image to classify clothing into one of 14 categories."
69
+ ).launch()