AkashDataScience commited on
Commit
cebd76b
·
1 Parent(s): 275a5a2

Code optimization

Browse files
Files changed (2) hide show
  1. app.py +6 -14
  2. utils.py +3 -0
app.py CHANGED
@@ -27,7 +27,8 @@ test_loader = dataset.get_test_data_loader(**dataloader_args)
27
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
28
  'dog', 'frog', 'horse', 'ship', 'truck')
29
 
30
- cache_dict = {"misclassified_images": None, "is_misclassified_images": None, "num_misclassified_images": None}
 
31
 
32
  def resize_image_pil(image, new_width, new_height):
33
 
@@ -86,20 +87,11 @@ def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_numb
86
  # Pick the top n predictions
87
  top_n_confidences = dict(list(sorted_confidences.items())[:top_predictions])
88
 
89
- if (is_misclassified_images != cache_dict["is_misclassified_images"] or
90
- num_misclassified_images != cache_dict["num_misclassified_images"]):
91
- cache_dict["is_misclassified_images"] = is_misclassified_images
92
- cache_dict["num_misclassified_images"] = num_misclassified_images
93
- if is_misclassified_images:
94
- # Get the misclassified data from test dataset
95
- misclassified_data = get_misclassified_data(model, device, test_loader)
96
- # Plot the misclassified data
97
- misclassified_images = display_cifar_misclassified_data(misclassified_data, number_of_samples=num_misclassified_images)
98
- cache_dict["misclassified_images"] = misclassified_images
99
- else:
100
- misclassified_images = None
101
  else:
102
- misclassified_images = cache_dict["misclassified_images"]
103
 
104
  return classes[prediction[0].item()], visualization, top_n_confidences, misclassified_images
105
 
 
27
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
28
  'dog', 'frog', 'horse', 'ship', 'truck')
29
 
30
+ # Get the misclassified data from test dataset
31
+ misclassified_data = get_misclassified_data(model, device, test_loader)
32
 
33
  def resize_image_pil(image, new_width, new_height):
34
 
 
87
  # Pick the top n predictions
88
  top_n_confidences = dict(list(sorted_confidences.items())[:top_predictions])
89
 
90
+ if is_misclassified_images:
91
+ # Plot the misclassified data
92
+ misclassified_images = display_cifar_misclassified_data(misclassified_data, number_of_samples=num_misclassified_images)
 
 
 
 
 
 
 
 
 
93
  else:
94
+ misclassified_images = None
95
 
96
  return classes[prediction[0].item()], visualization, top_n_confidences, misclassified_images
97
 
utils.py CHANGED
@@ -51,6 +51,9 @@ def get_misclassified_data(model, device, test_loader):
51
  with torch.no_grad():
52
  # Extract images, labels in a batch
53
  for data, target in test_loader:
 
 
 
54
 
55
  # Migrate the data to the device
56
  data, target = data.to(device), target.to(device)
 
51
  with torch.no_grad():
52
  # Extract images, labels in a batch
53
  for data, target in test_loader:
54
+
55
+ if len(misclassified_data) > 40:
56
+ break
57
 
58
  # Migrate the data to the device
59
  data, target = data.to(device), target.to(device)