rtik007 commited on
Commit
836b6de
·
verified ·
1 Parent(s): 4ef367f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -114
app.py CHANGED
@@ -1,125 +1,57 @@
1
  import torch
2
- from transformers import ViTForImageClassification, ViTImageProcessor
 
3
  from PIL import Image
4
  import gradio as gr
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
 
8
- # Load the pretrained Vision Transformer model and image processor
9
- model_name = "google/vit-base-patch16-224"
10
- try:
11
- model = ViTForImageClassification.from_pretrained(model_name)
12
- except Exception as e:
13
- print(f"Error loading model: {e}")
14
- image_processor = ViTImageProcessor.from_pretrained(model_name)
15
 
16
- # NIH Chest X-ray predefined conditions
17
- labels = [
18
- "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
19
- "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema",
20
- "Fibrosis", "Pleural Thickening", "Hernia"
21
- ]
 
 
 
22
 
23
- # Function to apply Grad-CAM visualization
24
- def generate_grad_cam(image, target_layer):
25
- try:
26
- # Convert image to RGB if necessary
27
- if image.mode != 'RGB':
28
- image = image.convert('RGB')
29
-
30
- # Preprocess the image
31
- inputs = image_processor(images=image, return_tensors="pt")
32
-
33
- # Forward pass to get logits
34
- input_tensor = inputs["pixel_values"]
35
- input_tensor.requires_grad = True # Enable gradient tracking
36
- outputs = model(input_tensor)
37
- logits = outputs.logits
38
-
39
- # Get the predicted class and calculate gradients
40
- predicted_class = logits.argmax(-1)
41
- class_score = logits[0, predicted_class]
42
- class_score.backward()
43
-
44
- # Get gradients and weights from the target layer
45
- gradients = model.get_input_embeddings().weight.grad
46
- pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
47
-
48
- # Apply Grad-CAM calculation (modify this part as per the model architecture)
49
- cam = torch.mean(pooled_gradients * inputs["pixel_values"], dim=1).squeeze()
50
- cam = torch.clamp(cam, min=0).numpy() # Ensure non-negative values
51
-
52
- return cam, predicted_class.item(), None # No error
53
- except Exception as e:
54
- error_message = f"Error generating Grad-CAM: {e}"
55
- print(error_message)
56
- return None, None, error_message
57
 
58
- # Function to predict classes and visualize Grad-CAM
59
- def predict_and_explain(image):
60
- try:
61
- # Convert image to RGB if necessary
62
- if image.mode != 'RGB':
63
- image = image.convert('RGB')
64
-
65
- # Preprocess the image
66
- inputs = image_processor(images=image, return_tensors="pt")
67
-
68
- # Forward pass to get logits
69
- input_tensor = inputs["pixel_values"]
70
- outputs = model(input_tensor)
71
- logits = outputs.logits
72
-
73
- predicted_class = logits.argmax(-1).item()
74
- cam_map, _, grad_cam_error = generate_grad_cam(image, "pooler_output")
75
-
76
- # Check for Grad-CAM errors
77
- if grad_cam_error is not None:
78
- return {
79
- "predicted class": "Error during Grad-CAM generation",
80
- "Grad-CAM map": None,
81
- "error log": grad_cam_error
82
- }
83
-
84
- # Convert cam_map to a visualizable format (heatmap)
85
- if cam_map is not None:
86
- plt.imshow(cam_map, cmap='jet', alpha=0.5)
87
- plt.axis('off')
88
- plt.title(f"Grad-CAM for {labels[predicted_class]}")
89
- plt.colorbar()
90
- plt.savefig("grad_cam_output.png")
91
- plt.close()
92
-
93
- # Load the saved image to return it
94
- grad_cam_image = Image.open("grad_cam_output.png")
95
- else:
96
- grad_cam_image = None
97
-
98
- return {
99
- "predicted class": labels[predicted_class],
100
- "Grad-CAM map": grad_cam_image,
101
- "error log": "No errors"
102
- }
103
- except Exception as e:
104
- error_message = f"Error predicting and explaining: {e}"
105
- print(error_message)
106
- return {
107
- "predicted class": "Error during prediction",
108
- "Grad-CAM map": None,
109
- "error log": error_message
110
- }
111
 
112
- # Use the updated Gradio components syntax
113
- iface = gr.Interface(
114
- fn=predict_and_explain,
115
- inputs=gr.Image(type="pil"), # Proper input type for images
116
- outputs=[
117
- gr.Textbox(label="Predicted Class"),
118
- gr.Image(label="Grad-CAM Map"),
119
- gr.Textbox(label="Error Log")
120
- ],
121
- title="Chest X-ray Classification with Debugging Logs"
122
  )
123
 
 
124
  if __name__ == "__main__":
125
- iface.launch()
 
1
  import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms, models
4
  from PIL import Image
5
  import gradio as gr
 
 
6
 
7
+ # Load the pre-trained DenseNet-121 model
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model = models.densenet121(pretrained=True)
 
 
 
 
10
 
11
+ # Modify the classifier layer to output probabilities for 14 classes (14 pathologies)
12
+ num_classes = 14
13
+ model.classifier = nn.Sequential(
14
+ nn.Linear(model.classifier.in_features, num_classes),
15
+ nn.Sigmoid() # Use Sigmoid for multi-label classification
16
+ )
17
+ model.load_state_dict(torch.load('chexnet.pth', map_location=device)) # Load your pre-trained weights
18
+ model = model.to(device)
19
+ model.eval()
20
 
21
+ # Define image transformations (resize, normalize)
22
+ transform = transforms.Compose([
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
+ ])
27
+
28
+ # Class names for the 14 diseases (labels from ChestX-ray14 dataset)
29
+ class_names = [
30
+ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
31
+ 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
32
+ 'Emphysema', 'Fibrosis', 'Pleural Thickening', 'Hernia'
33
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # Prediction function
36
+ def predict_disease(image):
37
+ image = transform(image).unsqueeze(0).to(device) # Transform and add batch dimension
38
+ with torch.no_grad():
39
+ outputs = model(image)
40
+ outputs = outputs.cpu().numpy().flatten()
41
+
42
+ # Create a dictionary of disease probabilities
43
+ result = {class_name: float(prob) for class_name, prob in zip(class_names, outputs)}
44
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # Gradio Interface
47
+ interface = gr.Interface(
48
+ fn=predict_disease,
49
+ inputs=gr.inputs.Image(type='pil'), # Input is an image
50
+ outputs="label", # Output is a dictionary of labels with probabilities
51
+ title="CheXNet Pneumonia Detection",
52
+ description="Upload a chest X-ray to detect the probability of 14 different diseases."
 
 
 
53
  )
54
 
55
+ # Launch the Gradio app
56
  if __name__ == "__main__":
57
+ interface.launch()