Gallai commited on
Commit
0b3aae0
·
verified ·
1 Parent(s): 33f2fd7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ from pathlib import Path
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from huggingface_hub import hf_hub_download
8
+ from ResNet_for_CC import CC_model
9
+
10
+ # Define the Clothing1M class labels
11
+ CLOTHING1M_CLASSES = [
12
+ "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater",
13
+ "Hoodie", "Windbreaker", "Jacket", "Downcoat",
14
+ "Suit", "Shawl", "Dress", "Vest", "Underwear"
15
+ ]
16
+
17
+ # Initialize the model
18
+ model = CC_model()
19
+ model_path = hf_hub_download(repo_id="mohamdlog/CC", filename="CC_net.pt")
20
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
21
+ model.eval()
22
+
23
+ # Define preprocessing pipeline
24
+ def preprocess_image(image):
25
+ if isinstance(image, np.ndarray):
26
+ image = Image.fromarray(image)
27
+ transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ ])
31
+ return transform(image).unsqueeze(0)
32
+
33
+ # Define classification function
34
+ def classify_image(image):
35
+ input_tensor = preprocess_image(image)
36
+ with torch.no_grad():
37
+ output = model(input_tensor)
38
+
39
+ # Get predicted class and confidence
40
+ probabilities = torch.nn.functional.softmax(output, dim=1)
41
+ predicted_class_idx = output.argmax(dim=1).item()
42
+ predicted_class = CLOTHING1M_CLASSES[predicted_class_idx]
43
+ confidence = probabilities[0][predicted_class_idx].item()
44
+
45
+ return f"Category: {predicted_class}\nConfidence: {confidence:.2f}"
46
+
47
+ # Create Gradio interface
48
+ interface = gr.Interface(
49
+ fn=classify_image,
50
+ inputs=gr.Image(label="Uploaded Image"),
51
+ outputs=gr.Text(label="Predicted Clothing"),
52
+ title="Clothing Category Classifier",
53
+ description = """
54
+ **Upload an image of clothing, and the model will predict its category.**
55
+ Try using an image that doesn't belong to any of the available categories, and see how the result differs!
56
+ **Categories:**
57
+ | T-Shirt | Shirt | Knitwear | Chiffon | Sweater | Hoodie | Windbreaker |
58
+ | Jacket | Downcoat | Suit | Shawl | Dress | Vest | Underwear |
59
+ """,
60
+ examples=[[str(file)] for file in Path("examples").glob("*")],
61
+ flagging_mode="never",
62
+ theme="soft"
63
+ )
64
+
65
+ # Launch the interface
66
+ if __name__ == "__main__":
67
+ interface.launch()