rtik007 commited on
Commit
27b2c9b
·
verified ·
1 Parent(s): 50a18a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -21
app.py CHANGED
@@ -12,23 +12,14 @@ try:
12
  except Exception as e:
13
  print(f"Error loading model: {e}")
14
  image_processor = ViTImageProcessor.from_pretrained(model_name)
 
15
  # NIH Chest X-ray predefined conditions
16
  labels = [
17
- "Atelectasis",
18
- "Cardiomegaly",
19
- "Effusion",
20
- "Infiltration",
21
- "Mass",
22
- "Nodule",
23
- "Pneumonia",
24
- "Pneumothorax",
25
- "Consolidation",
26
- "Edema",
27
- "Emphysema",
28
- "Fibrosis",
29
- "Pleural Thickening",
30
- "Hernia"
31
  ]
 
32
  # Function to apply Grad-CAM visualization
33
  def generate_grad_cam(image, target_layer):
34
  try:
@@ -37,18 +28,28 @@ def generate_grad_cam(image, target_layer):
37
 
38
  # Forward pass to get logits
39
  input_tensor = inputs["pixel_values"]
 
40
  outputs = model(input_tensor)
41
  logits = outputs.logits
42
 
43
- # Calculate Grad-CAM
44
- cam_weights = torch.mean(torch.relu(logits), dim=(2, 3))
45
- cam_map = (torch.unsqueeze(cam_weights, 1) *
46
- torch.sigmoid(outputs.pooler_output)).sum(dim=1).squeeze()
 
 
 
 
 
 
 
 
47
 
48
- return cam_map.numpy(), logits.argmax(-1)
49
  except Exception as e:
50
  print(f"Error generating Grad-CAM: {e}")
51
  return None
 
52
  # Function to predict classes and visualize Grad-CAM
53
  def predict_and_explain(image):
54
  try:
@@ -63,13 +64,25 @@ def predict_and_explain(image):
63
  predicted_class = logits.argmax(-1).item()
64
  cam_map, _ = generate_grad_cam(image, "pooler_output")
65
 
 
 
 
 
 
 
 
 
 
 
 
66
  return {
67
  "predicted class": labels[predicted_class],
68
- "Grad-CAM map": cam_map,
69
  }
70
  except Exception as e:
71
  print(f"Error predicting and explaining: {e}")
72
  return None
 
73
  # Create a Gradio interface
74
  iface = gr.Interface(
75
  fn=predict_and_explain,
@@ -79,4 +92,4 @@ iface = gr.Interface(
79
  )
80
 
81
  if __name__ == "__main__":
82
- iface.launch()
 
12
  except Exception as e:
13
  print(f"Error loading model: {e}")
14
  image_processor = ViTImageProcessor.from_pretrained(model_name)
15
+
16
  # NIH Chest X-ray predefined conditions
17
  labels = [
18
+ "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
19
+ "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema",
20
+ "Fibrosis", "Pleural Thickening", "Hernia"
 
 
 
 
 
 
 
 
 
 
 
21
  ]
22
+
23
  # Function to apply Grad-CAM visualization
24
  def generate_grad_cam(image, target_layer):
25
  try:
 
28
 
29
  # Forward pass to get logits
30
  input_tensor = inputs["pixel_values"]
31
+ input_tensor.requires_grad = True # Enable gradient tracking
32
  outputs = model(input_tensor)
33
  logits = outputs.logits
34
 
35
+ # Get the predicted class and calculate gradients
36
+ predicted_class = logits.argmax(-1)
37
+ class_score = logits[0, predicted_class]
38
+ class_score.backward()
39
+
40
+ # Get gradients and weights from the target layer
41
+ gradients = model.get_input_embeddings().weight.grad
42
+ pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
43
+
44
+ # Apply Grad-CAM calculation (modify this part as per the model architecture)
45
+ cam = torch.mean(pooled_gradients * inputs["pixel_values"], dim=1).squeeze()
46
+ cam = torch.clamp(cam, min=0).numpy() # Ensure non-negative values
47
 
48
+ return cam, predicted_class.item()
49
  except Exception as e:
50
  print(f"Error generating Grad-CAM: {e}")
51
  return None
52
+
53
  # Function to predict classes and visualize Grad-CAM
54
  def predict_and_explain(image):
55
  try:
 
64
  predicted_class = logits.argmax(-1).item()
65
  cam_map, _ = generate_grad_cam(image, "pooler_output")
66
 
67
+ # Convert cam_map to a visualizable format (heatmap)
68
+ plt.imshow(cam_map, cmap='jet', alpha=0.5)
69
+ plt.axis('off')
70
+ plt.title(f"Grad-CAM for {labels[predicted_class]}")
71
+ plt.colorbar()
72
+ plt.savefig("grad_cam_output.png")
73
+ plt.close()
74
+
75
+ # Load the saved image to return it
76
+ grad_cam_image = Image.open("grad_cam_output.png")
77
+
78
  return {
79
  "predicted class": labels[predicted_class],
80
+ "Grad-CAM map": grad_cam_image,
81
  }
82
  except Exception as e:
83
  print(f"Error predicting and explaining: {e}")
84
  return None
85
+
86
  # Create a Gradio interface
87
  iface = gr.Interface(
88
  fn=predict_and_explain,
 
92
  )
93
 
94
  if __name__ == "__main__":
95
+ iface.launch()