# Page residue from the hosting site, kept as a comment so the file parses:
# author "Harshithtd", commit 56b6a84 ("Update app.py"), verified.
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import gradio as gr
import os
import matplotlib.pyplot as plt
import seaborn as sns
# Intel OpenMP setting that tolerates more than one copy of the runtime being
# loaded in the same process.
# NOTE(review): this is assigned after torch and matplotlib have already been
# imported above; confirm it still takes effect early enough on the target
# platform.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Device configuration: use the GPU when one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Class labels are read one per line; the position of a label in this list is
# the class index used when decoding predictions.
with open("classes.txt", encoding="utf-8") as classes_file:
    class_names = [line.strip() for line in classes_file]

# Build the ViT-B/16 architecture without pretrained weights, then replace the
# classifier head with a linear layer sized to the number of classes.
model = torchvision.models.vit_b_16(weights=None)
model.heads = nn.Linear(in_features=768, out_features=len(class_names))

# NOTE(security): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source (consider weights_only=True where the installed
# torch version supports it).
checkpoint = torch.load(
    '08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth',
    map_location=torch.device('cpu'),
)

# strict=False tolerates key mismatches between checkpoint and model. Report
# them instead of silently serving a partially initialised network.
load_result = model.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Warning: checkpoint does not fully match the model "
        f"(missing keys: {load_result.missing_keys}; "
        f"unexpected keys: {load_result.unexpected_keys})"
    )

# Move the weights to the selected device and switch to inference behaviour.
model = model.to(device)
model.eval()
# Preprocessing pipeline applied to every input image before inference.
_preprocessing_steps = [
    # Resize with bilinear interpolation, then take the central 224x224 crop.
    transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    # Convert the PIL image to a tensor.
    transforms.ToTensor(),
    # Per-channel normalisation; presumably the ImageNet statistics the
    # backbone was trained with -- confirm against the training code.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def predict(image):
    """Classify an image and chart the most probable classes.

    Args:
        image: The uploaded picture as a NumPy array (the Gradio input is
            declared with ``type="numpy"``), or ``None`` when nothing was
            uploaded.

    Returns:
        A ``(top_class, figure)`` tuple: the name of the most probable class
        and a Matplotlib figure holding a horizontal bar chart of the top
        predictions (at most five).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # inside Image.fromarray.
        raise gr.Error("Please upload an image before submitting.")

    # Force three channels so the 3-channel normalisation also applies to
    # grayscale or RGBA input.
    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(dim=0).to(device)

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

    # Never ask for more entries than the model has classes; topk raises when
    # k exceeds the size of the dimension.
    k = min(5, probs.shape[1])
    top_probs, top_indices = torch.topk(probs, k=k)

    # Index the batch dimension explicitly: a bare squeeze() would collapse to
    # a 0-d array when k == 1 and break the iteration below.
    top_probs = top_probs[0].cpu().numpy()
    top_classes = [class_names[i] for i in top_indices[0].tolist()]

    # Plot the probabilities as a horizontal bar chart.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x=top_probs, y=top_classes, palette="viridis", ax=ax)
    ax.set_xlabel('Probability')
    ax.set_ylabel('Class')
    ax.set_title(f'Top {k} Predictions')
    ax.set_xlim(0, 1)

    # Annotate each bar with its probability, just past the bar's end.
    for bar in ax.patches:
        ax.text(
            bar.get_width() + 0.02,
            bar.get_y() + 0.55,
            f'{bar.get_width():.2f}',
            ha='center',
            va='center',
            fontsize=10,
            color='black',
        )

    # Operate on this figure/axes explicitly rather than pyplot's implicit
    # "current" figure, which is shared global state.
    sns.despine(ax=ax, left=True, bottom=True)
    fig.tight_layout()

    # Drop the figure from pyplot's global registry so figures do not pile up
    # across requests; the Figure object itself is still returned for display.
    plt.close(fig)

    return top_classes[0], fig
# Build the Gradio UI around predict(): one image input, two outputs.
_image_input = gr.Image(type="numpy")
_prediction_outputs = [
    gr.Textbox(label="Top Prediction"),  # name of the most probable class
    gr.Plot(),  # bar chart of the top predictions
]
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_outputs,
)

# Start serving the app.
iface.launch()