rtik007 commited on
Commit
9131c16
·
verified ·
1 Parent(s): 27b2c9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -17
app.py CHANGED
@@ -23,6 +23,10 @@ labels = [
23
  # Function to apply Grad-CAM visualization
24
  def generate_grad_cam(image, target_layer):
25
  try:
 
 
 
 
26
  # Preprocess the image
27
  inputs = image_processor(images=image, return_tensors="pt")
28
 
@@ -45,14 +49,19 @@ def generate_grad_cam(image, target_layer):
45
  cam = torch.mean(pooled_gradients * inputs["pixel_values"], dim=1).squeeze()
46
  cam = torch.clamp(cam, min=0).numpy() # Ensure non-negative values
47
 
48
- return cam, predicted_class.item()
49
  except Exception as e:
50
- print(f"Error generating Grad-CAM: {e}")
51
- return None
 
52
 
53
  # Function to predict classes and visualize Grad-CAM
54
  def predict_and_explain(image):
55
  try:
 
 
 
 
56
  # Preprocess the image
57
  inputs = image_processor(images=image, return_tensors="pt")
58
 
@@ -62,33 +71,54 @@ def predict_and_explain(image):
62
  logits = outputs.logits
63
 
64
  predicted_class = logits.argmax(-1).item()
65
- cam_map, _ = generate_grad_cam(image, "pooler_output")
 
 
 
 
 
 
 
 
66
 
67
  # Convert cam_map to a visualizable format (heatmap)
68
- plt.imshow(cam_map, cmap='jet', alpha=0.5)
69
- plt.axis('off')
70
- plt.title(f"Grad-CAM for {labels[predicted_class]}")
71
- plt.colorbar()
72
- plt.savefig("grad_cam_output.png")
73
- plt.close()
 
74
 
75
- # Load the saved image to return it
76
- grad_cam_image = Image.open("grad_cam_output.png")
 
 
77
 
78
  return {
79
  "predicted class": labels[predicted_class],
80
  "Grad-CAM map": grad_cam_image,
 
81
  }
82
  except Exception as e:
83
- print(f"Error predicting and explaining: {e}")
84
- return None
 
 
 
 
 
85
 
86
- # Create a Gradio interface
87
  iface = gr.Interface(
88
  fn=predict_and_explain,
89
  inputs="image",
90
- outputs=["text", "image"],
91
- title="Chest X-ray Classification"
 
 
 
 
92
  )
93
 
94
  if __name__ == "__main__":
 
23
  # Function to apply Grad-CAM visualization
24
  def generate_grad_cam(image, target_layer):
25
  try:
26
+ # Convert image to RGB if necessary
27
+ if image.mode != 'RGB':
28
+ image = image.convert('RGB')
29
+
30
  # Preprocess the image
31
  inputs = image_processor(images=image, return_tensors="pt")
32
 
 
49
  cam = torch.mean(pooled_gradients * inputs["pixel_values"], dim=1).squeeze()
50
  cam = torch.clamp(cam, min=0).numpy() # Ensure non-negative values
51
 
52
+ return cam, predicted_class.item(), None # No error
53
  except Exception as e:
54
+ error_message = f"Error generating Grad-CAM: {e}"
55
+ print(error_message)
56
+ return None, None, error_message
57
 
58
  # Function to predict classes and visualize Grad-CAM
59
  def predict_and_explain(image):
60
  try:
61
+ # Convert image to RGB if necessary
62
+ if image.mode != 'RGB':
63
+ image = image.convert('RGB')
64
+
65
  # Preprocess the image
66
  inputs = image_processor(images=image, return_tensors="pt")
67
 
 
71
  logits = outputs.logits
72
 
73
  predicted_class = logits.argmax(-1).item()
74
+ cam_map, _, grad_cam_error = generate_grad_cam(image, "pooler_output")
75
+
76
+ # Check for Grad-CAM errors
77
+ if grad_cam_error is not None:
78
+ return {
79
+ "predicted class": "Error during Grad-CAM generation",
80
+ "Grad-CAM map": None,
81
+ "error log": grad_cam_error
82
+ }
83
 
84
  # Convert cam_map to a visualizable format (heatmap)
85
+ if cam_map is not None:
86
+ plt.imshow(cam_map, cmap='jet', alpha=0.5)
87
+ plt.axis('off')
88
+ plt.title(f"Grad-CAM for {labels[predicted_class]}")
89
+ plt.colorbar()
90
+ plt.savefig("grad_cam_output.png")
91
+ plt.close()
92
 
93
+ # Load the saved image to return it
94
+ grad_cam_image = Image.open("grad_cam_output.png")
95
+ else:
96
+ grad_cam_image = None
97
 
98
  return {
99
  "predicted class": labels[predicted_class],
100
  "Grad-CAM map": grad_cam_image,
101
+ "error log": "No errors"
102
  }
103
  except Exception as e:
104
+ error_message = f"Error predicting and explaining: {e}"
105
+ print(error_message)
106
+ return {
107
+ "predicted class": "Error during prediction",
108
+ "Grad-CAM map": None,
109
+ "error log": error_message
110
+ }
111
 
112
+ # Create a Gradio interface with an error log output
113
  iface = gr.Interface(
114
  fn=predict_and_explain,
115
  inputs="image",
116
+ outputs=[
117
+ gr.outputs.Textbox(label="Predicted Class"),
118
+ gr.outputs.Image(label="Grad-CAM Map"),
119
+ gr.outputs.Textbox(label="Error Log") # Error log for debugging
120
+ ],
121
+ title="Chest X-ray Classification with Debugging Logs"
122
  )
123
 
124
  if __name__ == "__main__":