import random
from collections import OrderedDict

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms

dropout_value = 0.1  # not used in the current architecture
class ResBlock(nn.Module):
    """Two 3x3 Conv-BN-ReLU layers; the residual addition happens in the parent network."""
    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.res_block(x)


class LayerBlock(nn.Module):
    """3x3 Conv followed by 2x2 MaxPool, BatchNorm and ReLU (halves the spatial size)."""
    def __init__(self, in_channels, out_channels):
        super(LayerBlock, self).__init__()
        self.layer_block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layer_block(x)
class custom_resnet_s10(nn.Module):
    """Custom ResNet for 32x32 CIFAR-10 images."""
    def __init__(self, num_classes=10):
        super(custom_resnet_s10, self).__init__()
        self.PrepLayer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.Layer1 = LayerBlock(in_channels=64, out_channels=128)
        self.resblock1 = ResBlock(in_channels=128, out_channels=128)
        self.Layer2 = LayerBlock(in_channels=128, out_channels=256)
        self.resblock2 = ResBlock(in_channels=256, out_channels=256)
        self.Layer3 = LayerBlock(in_channels=256, out_channels=512)
        self.resblock3 = ResBlock(in_channels=512, out_channels=512)
        self.max_pool4 = nn.MaxPool2d(kernel_size=4, stride=4)  # 4x4x512 feature map -> 1x1x512
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.PrepLayer(x)
        x = self.Layer1(x)
        x = x + self.resblock1(x)  # residual connection
        x = self.Layer2(x)
        x = x + self.resblock2(x)  # residual connection
        x = self.Layer3(x)
        x = x + self.resblock3(x)  # residual connection
        x = self.max_pool4(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
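
# Illustrative shape walk-through for the network above (a sketch, assuming the 32x32
# CIFAR-10 inputs used by the rest of this app):
#   Prep: 32x32x64 -> Layer1: 16x16x128 -> Layer2: 8x8x256 -> Layer3: 4x4x512
#   -> 4x4 max pool: 1x1x512 -> flatten: 512 -> fc: num_classes
# Optional sanity check (uncomment to run once at import):
# assert custom_resnet_s10()(torch.randn(1, 3, 32, 32)).shape == (1, 10)
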
def get_device():
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Device Selected:", device)
    return device


DEVICE = get_device()
# Load the list of (image, predicted label, correct label) tensors saved during evaluation
loaded_misclassified_image_list = torch.load('misclassified_images_list.pt')

# Instantiate the model (must match the architecture the weights were trained with)
loaded_model = custom_resnet_s10()
loaded_model = loaded_model.to(DEVICE)

# Load the saved state dictionary and switch to evaluation mode
loaded_model.load_state_dict(torch.load('model.pth', map_location=DEVICE), strict=False)
loaded_model.eval()

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# CIFAR-10 per-channel statistics used for normalization
mean = (0.49139968, 0.48215827, 0.44653124)
std = (0.24703233, 0.24348505, 0.26158768)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
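
# Note on denormalization: to render a normalized tensor as an image again, the Normalize
# above has to be inverted per channel, x = x_norm * std + mean, which is itself a
# Normalize with mean' = -mean/std and std' = 1/std. process_gradcam_images below builds
# its inverse transform this way from the same `mean`/`std` tuples.
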
# Candidate GradCAM target layers: the last ReLU of the 2nd and 3rd residual blocks
dict_layer = {'layer3': loaded_model.resblock2.res_block[-1],
              'layer4': loaded_model.resblock3.res_block[-1]}

def view_gradcam_images(choice_gradcam):
    if choice_gradcam == "Yes (View Existing Images)":
        return gr.update(label="Number of GradCAM Images to view", visible=True, interactive=True), \
               gr.update(visible=True), \
               gr.update(visible=True), gr.update(visible=True), \
               gr.update(visible=False)  # Gallery not shown yet
    else:
        # TODO: to be completed
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
               gr.update(visible=False), gr.update(visible=False)

def process_gradcam_images(num_images, layer, opacity, image_list=None):
    if not image_list:
        selected_data = random.sample(loaded_misclassified_image_list, min(num_images, len(loaded_misclassified_image_list)))
    else:
        selected_data = [image_list]
    layer_model = dict_layer.get(layer)
    cam = GradCAM(model=loaded_model, target_layers=[layer_model], use_cuda=False)
    grad_images = []
    # Invert the CIFAR-10 normalization defined above: Normalize(mean=-mean/std, std=1/std)
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std]
    )
    for i, (img, pred, correct) in enumerate(selected_data):
        input_tensor = img.unsqueeze(0)
        targets = [ClassifierOutputTarget(pred)]
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]
        # Recover the original image for overlaying the CAM
        img = input_tensor.squeeze(0).to('cpu')
        img = inv_normalize(img)
        rgb_img = img.permute(1, 2, 0)
        rgb_img = torch.clamp(rgb_img, min=0, max=1)
        rgb_img = rgb_img.numpy()
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity)
        if not image_list:
            grad_images.append((visualization, f'Pred: {classes[pred.cpu()]} | Truth: {classes[correct.cpu()]}'))
        else:
            grad_images.append((visualization, f'Prediction: {classes[pred.cpu()]}'))
    print(f"GradCAM request -> images: {num_images}, layer: {layer}, opacity: {opacity}")
    return grad_images, gr.update(visible=True)

def process_misclassified_images(num_images):
    selected_data = random.sample(loaded_misclassified_image_list, min(num_images, len(loaded_misclassified_image_list)))
    misclassified_images = []
    for i, (img, pred, correct) in enumerate(selected_data):
        img, pred, correct = img.cpu().numpy().astype(dtype=np.float32), pred.cpu(), correct.cpu()
        # Undo the per-channel normalization before converting back to a PIL image
        for j in range(img.shape[0]):
            img[j] = (img[j] * std[j]) + mean[j]
        img = np.transpose(img, (1, 2, 0))
        img = Image.fromarray((img * 255).astype(np.uint8))
        misclassified_images.append((img, f'Pred: {classes[pred]} | Truth: {classes[correct]}'))
    return misclassified_images, gr.update(visible=True)

def view_misclassified_images(choice_misclassified):
    if choice_misclassified == "Yes":
        return gr.update(label="Number of Misclassified Images to view", visible=True, interactive=True), \
               gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

def classify_image(image, num_classes=3, grad_cam_choice=False, layer=None, opacity=0.8):
    # Transform the uploaded image and add a batch dimension
    transformed_image = transform(image)
    image_tensor = transformed_image.to(DEVICE).unsqueeze(0)
    # Run the model and turn the logits into class probabilities
    logits = loaded_model(image_tensor)
    output = F.softmax(logits.view(-1), dim=-1)
    confidences = [(classes[i], float(output[i])) for i in range(len(classes))]
    confidences.sort(key=lambda x: x[1], reverse=True)
    confidences = OrderedDict(confidences[:num_classes])
    label = torch.argmax(output).item()
    if grad_cam_choice:
        print("** Before calling GradCAM **", transformed_image.shape)
        # Reuse process_gradcam_images with a single (image, pred, pred) triple
        image_list = [transformed_image.to(DEVICE), torch.tensor(label).to(DEVICE), torch.tensor(label).to(DEVICE)]
        grad_cam_output, _ = process_gradcam_images(num_images=1, layer=layer, opacity=opacity, image_list=image_list)
        return confidences, grad_cam_output, gr.update(visible=True)
    else:
        return confidences, gr.update(visible=False), gr.update(visible=False)

with gr.Blocks() as demo:
    with gr.Tab("GradCam"):
        gr.Markdown(
            """
            Visualize the Class Activation Maps (they show what the model is actually looking at in an image) generated by a chosen model layer for the predicted class:
            - for existing images
            - for new images (choose an example image or upload your own)
            """
        )
        with gr.Column():
            with gr.Box():
                radio_gradcam = gr.Radio(["Yes (View Existing Images)", "No (New or Example Images)"], label="Do you want to view existing GradCAM images?")
                with gr.Column():
                    with gr.Row():
                        slider_gradcam_num_images = gr.Slider(minimum=1, maximum=10, value=1, step=1, visible=False, interactive=False)
                        dropdown_gradcam_layer = gr.Dropdown(choices=['layer4', 'layer3'], value="layer4", label="Please select the layer from which the GradCAM should be taken", interactive=True, visible=False)
                        slider_gradcam_opacity = gr.Slider(label="Opacity of Images", minimum=0.05, maximum=1.00, value=0.70, step=0.05, visible=False, interactive=True)
                    button_gradcam = gr.Button("View GradCAM Output", visible=False)
                    # txt_gradcam = gr.Textbox("GradCAM output here", visible=True)
                    output_gallery_gradcam = gr.Gallery(label="GradCAM Output", min_width=512, columns=4, visible=False)
            with gr.Box():
                with gr.Row():
                    with gr.Column():
                        input_image_classify = gr.Image(label="Classification", type="pil", shape=(32, 32))
                        slider_classify_num_classes = gr.Slider(label="Select the number of top classes to be shown", minimum=1, maximum=10, value=3, step=1, visible=True, interactive=True)
                        checkbox_gradcam_classify = gr.Checkbox(label="Enable GradCAM", value=True, info="Do you want to see Class Activation Maps?", visible=True)
                        # txt_classify = gr.Textbox("Classification output here", visible=True)
                        dropdown_gradcam_classify_layer = gr.Dropdown(choices=['layer4', 'layer3'], value="layer4", label="Please select the layer from which the GradCAM should be taken", interactive=True, visible=True)
                        slider_gradcam_classify_opacity = gr.Slider(label="Opacity of Images", minimum=0.05, maximum=1.00, value=0.80, step=0.05, visible=True, interactive=True)
                        button_classify = gr.Button("Submit to Classify Image", visible=True)
                    with gr.Column():
                        label_classify = gr.Label(num_top_classes=10, visible=True)
                        gallery_gradcam_classify = gr.Gallery(label="GradCAM Output", min_width=256, columns=1, visible=True)
                with gr.Row():
                    gr.Examples(['bird1.jpg', 'car1.jpg', 'deer1.jpg', 'frog1.jpg', 'plane1.jpg', 'ship1.jpg', 'truck1.jpg', 'cat1.jpg', 'dog1.jpg', 'horse1.jpg'], inputs=[input_image_classify])
with gr.Tab("Misclassified Examples"): | |
gr.Markdown( | |
""" | |
The AI model is not able to predict correct image labels all the time. | |
Select "Yes" to visualize the misclassified images with their model predicted label and ground truth label. | |
""" | |
) | |
with gr.Column(): | |
with gr.Box(): | |
radio_misclassified = gr.Radio(["Yes", "No"], label="Do you want to view Misclassified images?") | |
slider_misclassified_num_images = gr.Slider(minimum=1, maximum =10, value = 1, step =1, visible= False, interactive = False) | |
button_misclassified = gr.Button("View Misclassified Output", visible = False) | |
# txt_misclassified = gr.Textbox ("Misclassified output here" , visible = True) | |
output_gallery_misclassification=gr.Gallery(label="Misclassification Output (Predicted/Truth)", min_width=512,columns=5, visible = False) | |
    # Event wiring
    radio_gradcam.change(fn=view_gradcam_images, inputs=radio_gradcam, outputs=[slider_gradcam_num_images, dropdown_gradcam_layer, slider_gradcam_opacity, button_gradcam, output_gallery_gradcam])
    button_gradcam.click(fn=process_gradcam_images, inputs=[slider_gradcam_num_images, dropdown_gradcam_layer, slider_gradcam_opacity], outputs=[output_gallery_gradcam, output_gallery_gradcam])
    radio_misclassified.change(fn=view_misclassified_images, inputs=radio_misclassified, outputs=[slider_misclassified_num_images, button_misclassified, output_gallery_misclassification])
    button_misclassified.click(fn=process_misclassified_images, inputs=[slider_misclassified_num_images], outputs=[output_gallery_misclassification, output_gallery_misclassification])
    button_classify.click(fn=classify_image, inputs=[input_image_classify, slider_classify_num_classes, checkbox_gradcam_classify, dropdown_gradcam_classify_layer, slider_gradcam_classify_opacity], outputs=[label_classify, gallery_gradcam_classify, gallery_gradcam_classify])

demo.launch()