Harshithtd commited on
Commit
77b576a
·
verified ·
1 Parent(s): 502a2bf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torchvision import transforms
5
+ from torchvision.transforms import InterpolationMode
6
+ from PIL import Image
7
+ import gradio as gr
8
+ import os
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+
12
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
13
+
14
+ # Device configuration
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # Assuming 'class_names' is already defined in your script
18
+ class_names = [line.strip() for line in open("classes.txt")]
19
+
20
+ # Load the model
21
+ model = torchvision.models.vit_b_16(weights=None) # Initialize the model architecture
22
+ model.heads = nn.Linear(in_features=768, out_features=len(class_names)) # Adjust the classifier head
23
+ checkpoint = torch.load('08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth')
24
+ model.load_state_dict(checkpoint, strict=False)
25
+ model = model.to(device)
26
+ model.eval()
27
+
28
+ # Define transformations
29
+ transform = transforms.Compose([
30
+ transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
31
+ transforms.CenterCrop(224),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
34
+ ])
35
+
36
+ # Prediction function
37
+ def predict(image):
38
+ img = Image.fromarray(image)
39
+ transformed_image = transform(img).unsqueeze(dim=0).to(device)
40
+
41
+ with torch.inference_mode():
42
+ target_image_pred = model(transformed_image)
43
+
44
+ target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
45
+ top_probs, top_indices = torch.topk(target_image_pred_probs, k=5)
46
+ top_probs = top_probs.squeeze().cpu().numpy()
47
+ top_indices = top_indices.squeeze().cpu().numpy()
48
+
49
+ top_classes = [class_names[i] for i in top_indices]
50
+
51
+ # Plotting the probabilities as a bar chart
52
+ fig, ax = plt.subplots(figsize=(10, 6))
53
+ sns.barplot(x=top_probs, y=top_classes, palette="viridis", ax=ax)
54
+ ax.set_xlabel('Probability')
55
+ ax.set_ylabel('Class')
56
+ ax.set_title('Top 5 Predictions')
57
+ ax.set_xlim(0, 1)
58
+ for i in ax.patches:
59
+ ax.text(i.get_width() + 0.02, i.get_y() + 0.55, f'{i.get_width():.2f}',
60
+ ha='center', va='center', fontsize=10, color='black')
61
+ sns.despine(left=True, bottom=True)
62
+
63
+ plt.tight_layout()
64
+
65
+ return top_classes[0], fig
66
+
67
+ # Create Gradio interface
68
+ iface = gr.Interface(
69
+ fn=predict,
70
+ inputs=gr.Image(type="numpy"),
71
+ outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()], # Textbox for top prediction and Plot for the bar chart
72
+ examples=[r"C:\Users\Asus\Desktop\download (1).jpg", r"C:\Users\Asus\Desktop\download (3).jpg"] # Optional: Add paths to example images
73
+ )
74
+
75
+ # Launch the Gradio app
76
+ iface.launch()