AkashDataScience commited on
Commit
a534ad2
·
1 Parent(s): fc7448c
Files changed (2) hide show
  1. app.py +2 -2
  2. visualize.py +2 -0
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
  import dataset
3
- import visualize
4
  import albumentations
5
  from utils import get_misclassified_data
6
  from albumentations.pytorch import ToTensorV2
 
7
  from torchvision import transforms
8
  import numpy as np
9
  import gradio as gr
@@ -88,7 +88,7 @@ def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_numb
88
  # Get the misclassified data from test dataset
89
  misclassified_data = get_misclassified_data(model, device, test_loader)
90
  # Plot the misclassified data
91
- visualize.display_cifar_misclassified_data(misclassified_data, number_of_samples=num_missclassified_images)
92
  else:
93
  missclassified_images = None
94
 
 
1
  import torch
2
  import dataset
 
3
  import albumentations
4
  from utils import get_misclassified_data
5
  from albumentations.pytorch import ToTensorV2
6
+ from visualize import display_cifar_misclassified_data
7
  from torchvision import transforms
8
  import numpy as np
9
  import gradio as gr
 
88
  # Get the misclassified data from test dataset
89
  misclassified_data = get_misclassified_data(model, device, test_loader)
90
  # Plot the misclassified data
91
+ missclassified_images = display_cifar_misclassified_data(misclassified_data, number_of_samples=num_missclassified_images)
92
  else:
93
  missclassified_images = None
94
 
visualize.py CHANGED
@@ -64,6 +64,8 @@ def display_cifar_misclassified_data(data: list,
64
  plt.xticks([])
65
  plt.yticks([])
66
 
 
 
67
  def display_gradcam_output(data: list,
68
  model,
69
  target_layers,
 
64
  plt.xticks([])
65
  plt.yticks([])
66
 
67
+ return fig
68
+
69
  def display_gradcam_output(data: list,
70
  model,
71
  target_layers,