Spaces:
Sleeping
Sleeping
Commit
·
a534ad2
1
Parent(s):
fc7448c
Minor fix
Browse files- app.py +2 -2
- visualize.py +2 -0
app.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
import dataset
|
3 |
-
import visualize
|
4 |
import albumentations
|
5 |
from utils import get_misclassified_data
|
6 |
from albumentations.pytorch import ToTensorV2
|
|
|
7 |
from torchvision import transforms
|
8 |
import numpy as np
|
9 |
import gradio as gr
|
@@ -88,7 +88,7 @@ def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_numb
|
|
88 |
# Get the misclassified data from test dataset
|
89 |
misclassified_data = get_misclassified_data(model, device, test_loader)
|
90 |
# Plot the misclassified data
|
91 |
-
|
92 |
else:
|
93 |
missclassified_images = None
|
94 |
|
|
|
1 |
import torch
|
2 |
import dataset
|
|
|
3 |
import albumentations
|
4 |
from utils import get_misclassified_data
|
5 |
from albumentations.pytorch import ToTensorV2
|
6 |
+
from visualize import display_cifar_misclassified_data
|
7 |
from torchvision import transforms
|
8 |
import numpy as np
|
9 |
import gradio as gr
|
|
|
88 |
# Get the misclassified data from test dataset
|
89 |
misclassified_data = get_misclassified_data(model, device, test_loader)
|
90 |
# Plot the misclassified data
|
91 |
+
missclassified_images = display_cifar_misclassified_data(misclassified_data, number_of_samples=num_missclassified_images)
|
92 |
else:
|
93 |
missclassified_images = None
|
94 |
|
visualize.py
CHANGED
@@ -64,6 +64,8 @@ def display_cifar_misclassified_data(data: list,
|
|
64 |
plt.xticks([])
|
65 |
plt.yticks([])
|
66 |
|
|
|
|
|
67 |
def display_gradcam_output(data: list,
|
68 |
model,
|
69 |
target_layers,
|
|
|
64 |
plt.xticks([])
|
65 |
plt.yticks([])
|
66 |
|
67 |
+
return fig
|
68 |
+
|
69 |
def display_gradcam_output(data: list,
|
70 |
model,
|
71 |
target_layers,
|