rtik007 commited on
Commit
50a18a1
·
verified ·
1 Parent(s): e873a09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -76
app.py CHANGED
@@ -7,10 +7,11 @@ import matplotlib.pyplot as plt
7
 
8
  # Load the pretrained Vision Transformer model and image processor
9
  model_name = "google/vit-base-patch16-224"
10
- model = ViTForImageClassification.from_pretrained(model_name)
 
 
 
11
  image_processor = ViTImageProcessor.from_pretrained(model_name)
12
- model.eval()
13
-
14
  # NIH Chest X-ray predefined conditions
15
  labels = [
16
  "Atelectasis",
@@ -28,82 +29,54 @@ labels = [
28
  "Pleural Thickening",
29
  "Hernia"
30
  ]
31
-
32
  # Function to apply Grad-CAM visualization
33
  def generate_grad_cam(image, target_layer):
34
- # Preprocess the image
35
- inputs = image_processor(images=image, return_tensors="pt")
36
- input_tensor = inputs['pixel_values']
37
-
38
- # Forward pass to get logits
39
- input_tensor.requires_grad = True
40
- outputs = model(input_tensor)
41
-
42
- # Get the target score
43
- score = outputs.logits[0].max()
44
-
45
- # Backpropagate to get gradients
46
- model.zero_grad()
47
- score.backward()
48
-
49
- # Get the gradients and activations from the target layer
50
- gradients = model.get_input_embeddings().weight.grad
51
- activations = model.get_input_embeddings().weight.data
52
-
53
- # Calculate Grad-CAM
54
- pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
55
- for i in range(activations.size(1)):
56
- activations[:, i, :, :] *= pooled_gradients[i]
57
-
58
- heatmap = torch.mean(activations, dim=1).squeeze()
59
- heatmap = np.maximum(heatmap.detach().numpy(), 0)
60
- heatmap = heatmap / np.max(heatmap)
61
-
62
- return heatmap
63
-
64
- # Prediction and Grad-CAM function
65
  def predict_and_explain(image):
66
- # Predict the class
67
- inputs = image_processor(images=image, return_tensors="pt")
68
- with torch.no_grad():
69
- outputs = model(**inputs)
70
-
71
- logits = outputs.logits
72
- predicted_class_idx = logits.argmax(-1).item()
73
-
74
- # Get the predicted label based on NIH Chest X-ray conditions
75
- predicted_label = labels[predicted_class_idx]
76
-
77
- # Generate Grad-CAM heatmap
78
- heatmap = generate_grad_cam(image, target_layer="vit.encoder.layer.11.output")
79
-
80
- # Visualize the heatmap on the original image
81
- img = np.array(image)
82
- heatmap_resized = np.array(Image.fromarray(heatmap).resize((img.shape[1], img.shape[0])))
83
-
84
- # Overlay heatmap on the original image
85
- plt.imshow(img)
86
- plt.imshow(heatmap_resized, cmap='jet', alpha=0.5)
87
- plt.axis('off')
88
-
89
- # Save the overlayed image
90
- plt.savefig("grad_cam_result.png")
91
-
92
- return predicted_label, "grad_cam_result.png"
93
-
94
- # Gradio interface
95
- interface = gr.Interface(
96
- fn=predict_and_explain,
97
- inputs=gr.Image(type="pil"),
98
- outputs=[
99
- "text",
100
- gr.Image(type="file", label="Grad-CAM Visualization")
101
- ],
102
- title="Medical Image Analysis Tool with NIH Chest X-ray",
103
- description="Upload a Chest X-ray image to get a prediction for common thoracic conditions based on the NIH dataset, with explainability through Grad-CAM.",
104
- live=True
105
  )
106
 
107
- # Launch the app
108
  if __name__ == "__main__":
109
- interface.launch()
 
7
 
8
  # Load the pretrained Vision Transformer model and image processor
9
  model_name = "google/vit-base-patch16-224"
10
+ try:
11
+ model = ViTForImageClassification.from_pretrained(model_name)
12
+ except Exception as e:
13
+ print(f"Error loading model: {e}")
14
  image_processor = ViTImageProcessor.from_pretrained(model_name)
 
 
15
  # NIH Chest X-ray predefined conditions
16
  labels = [
17
  "Atelectasis",
 
29
  "Pleural Thickening",
30
  "Hernia"
31
  ]
 
32
  # Function to apply Grad-CAM visualization
33
  def generate_grad_cam(image, target_layer):
34
+ try:
35
+ # Preprocess the image
36
+ inputs = image_processor(images=image, return_tensors="pt")
37
+
38
+ # Forward pass to get logits
39
+ input_tensor = inputs["pixel_values"]
40
+ outputs = model(input_tensor)
41
+ logits = outputs.logits
42
+
43
+ # Calculate Grad-CAM
44
+ cam_weights = torch.mean(torch.relu(logits), dim=(2, 3))
45
+ cam_map = (torch.unsqueeze(cam_weights, 1) *
46
+ torch.sigmoid(outputs.pooler_output)).sum(dim=1).squeeze()
47
+
48
+ return cam_map.numpy(), logits.argmax(-1)
49
+ except Exception as e:
50
+ print(f"Error generating Grad-CAM: {e}")
51
+ return None
52
+ # Function to predict classes and visualize Grad-CAM
 
 
 
 
 
 
 
 
 
 
 
 
53
  def predict_and_explain(image):
54
+ try:
55
+ # Preprocess the image
56
+ inputs = image_processor(images=image, return_tensors="pt")
57
+
58
+ # Forward pass to get logits
59
+ input_tensor = inputs["pixel_values"]
60
+ outputs = model(input_tensor)
61
+ logits = outputs.logits
62
+
63
+ predicted_class = logits.argmax(-1).item()
64
+ cam_map, _ = generate_grad_cam(image, "pooler_output")
65
+
66
+ return {
67
+ "predicted class": labels[predicted_class],
68
+ "Grad-CAM map": cam_map,
69
+ }
70
+ except Exception as e:
71
+ print(f"Error predicting and explaining: {e}")
72
+ return None
73
+ # Create a Gradio interface
74
+ iface = gr.Interface(
75
+ fn=predict_and_explain,
76
+ inputs="image",
77
+ outputs=["text", "image"],
78
+ title="Chest X-ray Classification"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  )
80
 
 
81
  if __name__ == "__main__":
82
+ iface.launch()