"""Gradio demo that classifies an uploaded image with a fine-tuned ViT-B/16.

At import time the script reads the class list from ``classes.txt``, builds a
``vit_b_16`` model with a linear head sized to that list, loads the checkpoint
weights, and launches a Gradio interface that shows the top prediction together
with a bar chart of the highest class probabilities.
"""

import os
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

CLASSES_PATH = "classes.txt"
CHECKPOINT_PATH = "08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth"
# Upper bound on how many predictions are charted; clamped to the number of
# classes inside predict() because torch.topk rejects k larger than that.
TOP_K = 5

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"

# One class name per line; the context manager closes the file handle.
with open(CLASSES_PATH) as classes_file:
    class_names = [line.strip() for line in classes_file]

# Load the model
model = torchvision.models.vit_b_16(weights=None)  # Initialize the model architecture
# Adjust the classifier head so the output size matches the number of classes.
model.heads = nn.Linear(in_features=768, out_features=len(class_names))

# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# NOTE(review): consider weights_only=True if the installed torch supports it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))

# strict=False tolerates key mismatches, so report them instead of silently
# serving a model whose weights were only partially loaded.
load_result = model.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint did not match the model exactly: "
        f"missing keys={list(load_result.missing_keys)}, "
        f"unexpected keys={list(load_result.unexpected_keys)}",
        stacklevel=1,
    )

model = model.to(device)
model.eval()

# Define transformations
transform = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Prediction function
def predict(image):
    """Classify an image and chart the most probable classes.

    Args:
        image: Image array as delivered by the ``gr.Image(type="numpy")``
            input of the interface defined below.

    Returns:
        A ``(top_class, figure)`` tuple: the name of the most probable class
        and a matplotlib figure with a horizontal bar chart of the highest
        probabilities (at most ``TOP_K`` bars).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an image before submitting.")

    # Normalize() above uses 3-channel statistics, so force RGB input.
    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(dim=0).to(device)

    # torch.topk raises if k exceeds the number of classes.
    k = min(TOP_K, len(class_names))

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        top_probs, top_indices = torch.topk(probs, k=k)

    # Drop only the batch axis so a single result still yields a 1-D sequence.
    top_probs = top_probs.squeeze(0).cpu().numpy()
    top_indices = top_indices.squeeze(0).cpu().tolist()
    top_classes = [class_names[idx] for idx in top_indices]

    # Plotting the probabilities as a bar chart
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x=top_probs, y=top_classes, palette="viridis", ax=ax)
    ax.set_xlabel('Probability')
    ax.set_ylabel('Class')
    ax.set_title(f'Top {k} Predictions')
    ax.set_xlim(0, 1)

    # Label each bar with its probability, just past the end of the bar.
    for bar in ax.patches:
        ax.text(
            bar.get_width() + 0.02,
            bar.get_y() + 0.55,
            f'{bar.get_width():.2f}',
            ha='center',
            va='center',
            fontsize=10,
            color='black',
        )

    sns.despine(left=True, bottom=True)
    plt.tight_layout()

    # Remove the figure from pyplot's global registry so figures do not pile up
    # across requests; the Figure object itself is still returned.
    plt.close(fig)

    return top_classes[0], fig


# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    # Textbox for top prediction and Plot for the bar chart
    outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()],
)

# Launch the Gradio app
iface.launch()