AkashDataScience commited on
Commit
30b6d93
·
1 Parent(s): 47b7204

Added Top N classes

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -39,7 +39,7 @@ def resize_image_pil(image, new_width, new_height):
39
 
40
  return resized
41
 
42
- def inference(input_img, transparency = 0.5, is_grad_cam=True, target_layer_number = -1):
43
  input_img = resize_image_pil(input_img, 32, 32)
44
 
45
  input_img = np.array(input_img)
@@ -62,7 +62,14 @@ def inference(input_img, transparency = 0.5, is_grad_cam=True, target_layer_numb
62
  visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
63
  else:
64
  visualization = None
65
- return classes[prediction[0].item()], visualization, confidences
 
 
 
 
 
 
 
66
 
67
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
68
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
@@ -73,12 +80,13 @@ demo = gr.Interface(
73
  gr.Image(width=256, height=256, label="Input Image"),
74
  gr.Slider(0, 1, value = 0.5, label="Overall Opacity of Image"),
75
  gr.Checkbox(label="Show GradCAM"),
76
- gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?")
 
77
  ],
78
  outputs = [
79
  "text",
80
  gr.Image(width=256, height=256, label="Output"),
81
- gr.Label(num_top_classes=3)
82
  ],
83
  title = title,
84
  description = description,
 
39
 
40
  return resized
41
 
42
+ def inference(input_img, transparency = 0.5, is_grad_cam=True, target_layer_number = -1, top_predictions=3):
43
  input_img = resize_image_pil(input_img, 32, 32)
44
 
45
  input_img = np.array(input_img)
 
62
  visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
63
  else:
64
  visualization = None
65
+
66
+ # Sort the confidences dictionary based on confidence values
67
+ sorted_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True))
68
+
69
+ # Pick the top n predictions
70
+ top_n_confidences = dict(list(sorted_confidences.items())[:top_predictions])
71
+
72
+ return classes[prediction[0].item()], visualization, top_n_confidences
73
 
74
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
75
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
 
80
  gr.Image(width=256, height=256, label="Input Image"),
81
  gr.Slider(0, 1, value = 0.5, label="Overall Opacity of Image"),
82
  gr.Checkbox(label="Show GradCAM"),
83
+ gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?"),
84
+ gr.Slider(2, 10, value=3, step=1, label="Number of Top Classes")
85
  ],
86
  outputs = [
87
  "text",
88
  gr.Image(width=256, height=256, label="Output"),
89
+ gr.Label(label="Top Classes")
90
  ],
91
  title = title,
92
  description = description,