cifar10 / app.py
swapniel99's picture
Update app.py
2af7de6
raw
history blame
3.68 kB
import torch
import pandas as pd
import numpy as np
import gradio as gr
from PIL import Image
from torch.nn import functional as F
from collections import OrderedDict
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from models.custom_resnet import Model
from datasets import CIFAR10
# Dataset wrapper created with normalization, shuffling and augmentation all
# switched off; its `classes` mapping is used below to turn class ids into names.
cifar10 = CIFAR10(augment=False, shuffle=False, normalize=False)
# Accessed only for its side effect -- presumably this makes the wrapper load
# the test split (TODO confirm against datasets.CIFAR10).
_ = cifar10.test_data

# Misclassified test samples: replace the numeric class ids in both label
# columns with class names, then shuffle the rows (frac=1 keeps every row).
missed_df = (
    pd.read_csv('S12_incorrect.csv')
    .assign(
        ground_truths=lambda frame: frame['ground_truths'].map(cifar10.classes),
        predicted_vals=lambda frame: frame['predicted_vals'].map(cifar10.classes),
    )
    .sample(frac=1)
)

# Restore the trained weights on CPU and switch the network to inference mode.
model = Model(cifar10)
model.load_state_dict(torch.load('S12_model.pth', map_location='cpu'))
model.eval()

# Pre-processing applied to every image before it is fed to the model.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.25] * 3),
])
# Algebraic inverse of the Normalize step above: (x + 2) / 4 undoes (x - 0.5) / 0.25.
# Not referenced by the rest of the visible code.
inv_transform = transforms.Normalize(mean=[-2] * 3, std=[4] * 3)
def image_classifier(input_image, top_classes=3, show_cam=True, target_layer=-1, transparency=0.5):
    """Classify an image and optionally overlay a GradCAM heat map on it.

    Args:
        input_image: Image array. It is run through ``transform`` for the
            model and divided by 255 for the overlay, so 0-255 pixel values
            are assumed (TODO confirm against the callers).
        top_classes: Number of highest-confidence classes to report.
        show_cam: When true the GradCAM overlay is returned; otherwise
            ``input_image`` is returned unchanged and GradCAM is skipped.
        target_layer: Offset selecting the GradCAM layer as
            ``model.network[4 + target_layer]``.
        transparency: Forwarded to ``show_cam_on_image`` as ``image_weight``.

    Returns:
        A ``(image, confidences)`` tuple. ``confidences`` is an
        ``OrderedDict`` mapping class name to softmax probability, ordered
        from most to least likely and truncated to ``top_classes`` entries.
    """
    # UI sliders can deliver whole numbers as floats; slicing and layer
    # indexing both require real ints.
    top_classes = int(top_classes)
    layer_index = 4 + int(target_layer)

    batch = transform(input_image).unsqueeze(0)

    # The plain prediction needs no autograd graph; GradCAM runs its own
    # forward/backward pass further down.
    with torch.no_grad():
        probabilities = F.softmax(model(batch).flatten(), dim=-1)

    ranked = sorted(
        ((cifar10.classes[i], float(p)) for i, p in enumerate(probabilities)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    confidences = OrderedDict(ranked[:top_classes])

    # Skip the (comparatively expensive) GradCAM computation when the caller
    # does not want the overlay.
    if not show_cam:
        return input_image, confidences

    label = torch.argmax(probabilities).item()
    cam_layers = [model.network[layer_index]]
    cam_targets = [ClassifierOutputTarget(label)]
    # The context manager removes the hooks GradCAM registers on the model as
    # soon as the map is computed. `use_cuda` is intentionally not passed: it
    # defaulted to False in older pytorch-grad-cam releases and is no longer
    # accepted by newer ones.
    with GradCAM(model=model, target_layers=cam_layers) as grad_cam:
        grayscale_cam = grad_cam(input_tensor=batch, targets=cam_targets)[0, :]

    output_image = show_cam_on_image(
        input_image / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )
    return output_image, confidences
# First tab: classify one image. The input widgets are listed in the same
# order as the parameters of `image_classifier`.
demo1 = gr.Interface(
    fn=image_classifier,
    inputs=[
        gr.Image(label="Input Image", shape=(32, 32), value='examples/cat.jpg'),
        gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Top Classes"),
        gr.Checkbox(value=True, label="Show GradCAM?"),
        gr.Slider(minimum=-4, maximum=-1, step=1, value=-2, label="Which Layer?"),
        gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Transparency"),
    ],
    outputs=[
        gr.Image(label="Output Image", shape=(32, 32)),
        gr.Label(label='Top Classes'),
    ],
    # One clickable example image per class name.
    examples=[[f'examples/{class_name}.jpg'] for class_name in cifar10.classes.values()],
)
def show_incorrect(num_examples=10, show_cam=True, target_layer=-1, transparency=0.5):
    """Build gallery entries for the first ``num_examples`` rows of ``missed_df``.

    Args:
        num_examples: Number of rows taken from the head of ``missed_df``.
        show_cam: Forwarded to ``image_classifier``.
        target_layer: Forwarded to ``image_classifier``.
        transparency: Forwarded to ``image_classifier``.

    Returns:
        A list of ``(image, caption)`` tuples, where the caption reads
        ``"<ground truth> / <top predicted class>"``.
    """
    result = []
    for index, row in missed_df.head(num_examples).iterrows():
        # Image files are named after the DataFrame index. Open the file in a
        # context manager so its handle is closed once the pixels are copied
        # into the array, instead of lingering until garbage collection.
        with Image.open(f'missed_examples/{index}.jpg') as picture:
            image = np.asarray(picture)
        output_image, confidence = image_classifier(
            image,
            show_cam=show_cam,
            target_layer=target_layer,
            transparency=transparency,
        )
        # `confidence` is ordered best-first, so its first key is the prediction.
        predicted = next(iter(confidence))
        result.append((output_image, f"{row['ground_truths']} / {predicted}"))
    return result
# Second tab: gallery of misclassified test images. The input widgets are
# listed in the same order as the parameters of `show_incorrect`.
# The labels previously read "missclassified"; the spelling now matches the
# "Misclassified Examples" tab title.
demo2 = gr.Interface(
    fn=show_incorrect,
    inputs=[
        # The count is capped at the number of rows available in `missed_df`.
        gr.Number(value=20, minimum=1, maximum=len(missed_df), label="No. of misclassified Examples", precision=0),
        gr.Checkbox(label="Show GradCAM?", value=True),
        gr.Slider(-4, -1, value=-2, step=1, label="Which Layer?"),
        gr.Slider(0, 1, value=0.7, label="Transparency", step=0.1),
    ],
    outputs=[gr.Gallery(label="Misclassified Images (Truth / Predicted)", columns=4)],
)
# Combine both interfaces into a single tabbed app and start serving it.
demo = gr.TabbedInterface(
    interface_list=[demo1, demo2],
    tab_names=["Examples", "Misclassified Examples"],
)
demo.launch()