File size: 2,546 Bytes
77b576a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1844bf
77b576a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56b6a84
77b576a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode

# Let duplicate OpenMP runtimes coexist instead of aborting at start-up.
# NOTE(review): this is a workaround for a packaging conflict, not a fix.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Device configuration: use the GPU when one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Class labels, one per non-blank line of classes.txt. Their order must match
# the classifier's output indices, because predictions index into this list.
# Blank lines are skipped so they cannot inflate the head size with "" labels.
with open("classes.txt", encoding="utf-8") as classes_file:
    class_names = [line.strip() for line in classes_file if line.strip()]

# Load the model: rebuild the ViT-B/16 architecture without pretrained weights,
# replace its classifier head with one output per entry in ``class_names``, then
# restore the fine-tuned weights from the checkpoint.
model = torchvision.models.vit_b_16(weights=None)
# A bare Linear head; presumably this mirrors the architecture the checkpoint
# was trained with, since parameter names must line up -- TODO confirm.
# ``hidden_dim`` is the encoder width (768 for ViT-B/16).
model.heads = nn.Linear(in_features=model.hidden_dim, out_features=len(class_names))
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(
    '08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth',
    map_location=torch.device('cpu'),
)
# strict=False skips missing/unexpected keys instead of raising. Skipping them
# silently would leave the affected layers (e.g. the new head) randomly
# initialised and the predictions meaningless, so report any mismatch.
incompatible = model.load_state_dict(checkpoint, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match the model; "
        f"missing keys: {incompatible.missing_keys}; "
        f"unexpected keys: {incompatible.unexpected_keys}"
    )
model = model.to(device)
# Put the model in evaluation mode for inference.
model.eval()

# Preprocessing pipeline applied to every input image:
#   1. resize so the shorter side is 256 px (bilinear),
#   2. take the central 224x224 crop,
#   3. convert to a float tensor,
#   4. normalise each channel with the standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)

# Prediction function and its helpers.
def _classify(image, k=5):
    """Run the model on one image and return its ``k`` most probable classes.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``.
        k: Number of classes to return. It is clamped to ``len(class_names)``
            because ``torch.topk`` raises when ``k`` exceeds the number of
            model outputs (this checkpoint's class list may be shorter than 5).

    Returns:
        A ``(top_classes, top_probs)`` pair: a list of class names and a 1-D
        NumPy array of softmax probabilities, most probable first.
    """
    # convert("RGB") makes greyscale/RGBA inputs match the 3-channel
    # Normalize step in ``transform``.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(dim=0).to(device)

    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)
        top_probs, top_indices = torch.topk(probs, k=min(k, len(class_names)), dim=1)

    # Take row 0 instead of squeeze(): squeeze() yields a 0-d array when k == 1,
    # which cannot be iterated.
    top_probs = top_probs[0].cpu().numpy()
    top_indices = top_indices[0].cpu().numpy()
    return [class_names[i] for i in top_indices], top_probs


def _plot_probabilities(top_classes, top_probs):
    """Return a horizontal bar chart of ``top_probs`` labelled by ``top_classes``."""
    # Build the Figure directly instead of through pyplot. pyplot keeps every
    # figure alive in a global registry until closed (one leak per request),
    # and its implicit "current figure" is shared by concurrent requests.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    sns.barplot(x=top_probs, y=top_classes, palette="viridis", ax=ax)
    ax.set_xlabel('Probability')
    ax.set_ylabel('Class')
    ax.set_title(f'Top {len(top_classes)} Predictions')
    ax.set_xlim(0, 1)
    for bar in ax.patches:
        width = bar.get_width()
        # Place the value just right of the bar's end, vertically centred on it.
        ax.text(width + 0.02, bar.get_y() + bar.get_height() / 2, f'{width:.2f}',
                ha='left', va='center', fontsize=10, color='black')
    # Pass ax explicitly; without it seaborn would act on pyplot's current figure.
    sns.despine(ax=ax, left=True, bottom=True)
    fig.tight_layout()
    return fig


def predict(image):
    """Classify ``image`` and return the top label plus a probability chart.

    This is the Gradio callback wired to ``gr.Image(type="numpy")``.

    Args:
        image: The uploaded image as a NumPy array, or ``None`` when the user
            submits without an image.

    Returns:
        ``(top_label, figure)`` for the Textbox and Plot outputs respectively.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    top_classes, top_probs = _classify(image)
    return top_classes[0], _plot_probabilities(top_classes, top_probs)

# Gradio UI: a single image input feeding ``predict``, whose two return values
# fill a textbox (top prediction) and a plot (probability bar chart).
image_input = gr.Image(type="numpy")
output_components = [
    gr.Textbox(label="Top Prediction"),
    gr.Plot(),
]

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=output_components,
)

# Launch the Gradio app.
iface.launch()