rtik007 committed on
Commit d8a6e20 · verified · 1 Parent(s): 17e428a

Update app.py

Files changed (1)
app.py +92 -92
app.py CHANGED
@@ -1,92 +1,92 @@
- import torch
- from transformers import ViTForImageClassification, ViTFeatureExtractor
- from PIL import Image
- import gradio as gr
- import numpy as np
- import matplotlib.pyplot as plt
-
- # Load the pretrained Vision Transformer model and feature extractor
- model_name = "google/vit-base-patch16-224"
- model = ViTForImageClassification.from_pretrained(model_name)
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
- model.eval()
-
- # Function to apply Grad-CAM visualization
- def generate_grad_cam(image, target_layer):
-     # Preprocess the image
-     inputs = feature_extractor(images=image, return_tensors="pt")
-     input_tensor = inputs['pixel_values']
-
-     # Forward pass to get logits
-     input_tensor.requires_grad = True
-     outputs = model(input_tensor)
-
-     # Get the target score
-     score = outputs.logits[0].max()
-
-     # Backpropagate to get gradients
-     model.zero_grad()
-     score.backward()
-
-     # Get the gradients and activations from the target layer
-     gradients = model.get_input_embeddings().weight.grad
-     activations = model.get_input_embeddings().weight.data
-
-     # Calculate Grad-CAM
-     pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
-     for i in range(activations.size(1)):
-         activations[:, i, :, :] *= pooled_gradients[i]
-
-     heatmap = torch.mean(activations, dim=1).squeeze()
-     heatmap = np.maximum(heatmap.detach().numpy(), 0)
-     heatmap = heatmap / np.max(heatmap)
-
-     return heatmap
-
- # Prediction and Grad-CAM function
- def predict_and_explain(image):
-     # Predict the class
-     inputs = feature_extractor(images=image, return_tensors="pt")
-     with torch.no_grad():
-         outputs = model(**inputs)
-
-     logits = outputs.logits
-     predicted_class_idx = logits.argmax(-1).item()
-
-     # Predefined medical conditions (adjust based on your dataset)
-     labels = ["Class 1 - Normal", "Class 2 - Condition A", "Class 3 - Condition B"]
-     predicted_label = labels[predicted_class_idx]
-
-     # Generate Grad-CAM heatmap
-     heatmap = generate_grad_cam(image, target_layer="vit.encoder.layer.11.output")
-
-     # Visualize the heatmap on the original image
-     img = np.array(image)
-     heatmap_resized = np.array(Image.fromarray(heatmap).resize((img.shape[1], img.shape[0])))
-
-     # Overlay heatmap on the original image
-     plt.imshow(img)
-     plt.imshow(heatmap_resized, cmap='jet', alpha=0.5)
-     plt.axis('off')
-
-     # Save the overlayed image
-     plt.savefig("grad_cam_result.png")
-
-     return predicted_label, "grad_cam_result.png"
-
- # Gradio interface
- interface = gr.Interface(
-     fn=predict_and_explain,
-     inputs=gr.inputs.Image(type="pil"),
-     outputs=[
-         "text",
-         gr.outputs.Image(type="file", label="Grad-CAM Visualization")
-     ],
-     title="Medical Image Analysis Tool with Explainability",
-     description="Upload an X-ray or MRI image to get a prediction for a medical condition with explainability through Grad-CAM.",
-     live=True
- )
-
- # Launch the app
- if __name__ == "__main__":
-     interface.launch()
 
+ import torch
+ from transformers import ViTForImageClassification, ViTImageProcessor
+ from PIL import Image
+ import gradio as gr
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ # Load the pretrained Vision Transformer model and image processor
+ model_name = "google/vit-base-patch16-224"
+ model = ViTForImageClassification.from_pretrained(model_name)
+ image_processor = ViTImageProcessor.from_pretrained(model_name)
+ model.eval()
+
+ # Function to apply Grad-CAM visualization
+ def generate_grad_cam(image, target_layer):
+     # Preprocess the image
+     inputs = image_processor(images=image, return_tensors="pt")
+     input_tensor = inputs['pixel_values']
+
+     # Forward pass to get logits
+     input_tensor.requires_grad = True
+     outputs = model(input_tensor)
+
+     # Get the target score
+     score = outputs.logits[0].max()
+
+     # Backpropagate to get gradients
+     model.zero_grad()
+     score.backward()
+
+     # Get the gradients and activations from the target layer
+     gradients = model.get_input_embeddings().weight.grad
+     activations = model.get_input_embeddings().weight.data
+
+     # Calculate Grad-CAM
+     pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
+     for i in range(activations.size(1)):
+         activations[:, i, :, :] *= pooled_gradients[i]
+
+     heatmap = torch.mean(activations, dim=1).squeeze()
+     heatmap = np.maximum(heatmap.detach().numpy(), 0)
+     heatmap = heatmap / np.max(heatmap)
+
+     return heatmap
+
+ # Prediction and Grad-CAM function
+ def predict_and_explain(image):
+     # Predict the class
+     inputs = image_processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     logits = outputs.logits
+     predicted_class_idx = logits.argmax(-1).item()
+
+     # Predefined medical conditions (adjust based on your dataset)
+     labels = ["Class 1 - Normal", "Class 2 - Condition A", "Class 3 - Condition B"]
+     predicted_label = labels[predicted_class_idx]
+
+     # Generate Grad-CAM heatmap
+     heatmap = generate_grad_cam(image, target_layer="vit.encoder.layer.11.output")
+
+     # Visualize the heatmap on the original image
+     img = np.array(image)
+     heatmap_resized = np.array(Image.fromarray(heatmap).resize((img.shape[1], img.shape[0])))
+
+     # Overlay heatmap on the original image
+     plt.imshow(img)
+     plt.imshow(heatmap_resized, cmap='jet', alpha=0.5)
+     plt.axis('off')
+
+     # Save the overlayed image
+     plt.savefig("grad_cam_result.png")
+
+     return predicted_label, "grad_cam_result.png"
+
+ # Gradio interface
+ interface = gr.Interface(
+     fn=predict_and_explain,
+     inputs=gr.Image(type="pil"),
+     outputs=[
+         "text",
+         gr.Image(type="file", label="Grad-CAM Visualization")
+     ],
+     title="Medical Image Analysis Tool with Explainability",
+     description="Upload an X-ray or MRI image to get a prediction for a medical condition with explainability through Grad-CAM.",
+     live=True
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     interface.launch()
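
For context on this change: ViTFeatureExtractor has been deprecated in recent transformers releases in favor of ViTImageProcessor, and Gradio dropped the gr.inputs/gr.outputs namespaces in favor of top-level components such as gr.Image. Below is a minimal, self-contained sketch of the updated preprocessing and interface wiring. It is an illustration only, not part of the committed file; it assumes transformers >= 4.26 and a recent Gradio 4.x release, and the classify/demo names are invented for the example.

# Sketch (illustrative, not from the commit): ViTImageProcessor + top-level Gradio components.
import torch
import gradio as gr
from transformers import ViTForImageClassification, ViTImageProcessor

model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
image_processor = ViTImageProcessor.from_pretrained(model_name)  # replaces ViTFeatureExtractor
model.eval()

def classify(image):
    # The processor resizes and normalizes the PIL image and returns pixel_values.
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    idx = logits.argmax(-1).item()
    return model.config.id2label[idx]

demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),            # was gr.inputs.Image(type="pil")
    outputs=gr.Textbox(label="Prediction"),  # explicit component instead of the "text" shortcut
    title="ViT image classifier",
)

if __name__ == "__main__":
    demo.launch()

One further note: on recent Gradio releases the accepted type values for gr.Image are "numpy", "pil", and "filepath", so the committed output component gr.Image(type="file", ...) may need to become type="filepath" there.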