rtik007 commited on
Commit
aee2703
·
verified ·
1 Parent(s): fe51b53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -1,16 +1,12 @@
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms, models
4
- from torchvision.models import DenseNet121_Weights
5
  from PIL import Image
6
  import gradio as gr
7
 
8
  # Load the pre-trained DenseNet-121 model
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
-
11
- # Use DenseNet121 with updated weights parameter
12
- weights = DenseNet121_Weights.IMAGENET1K_V1
13
- model = models.densenet121(weights=weights)
14
 
15
  # Modify the classifier layer to output probabilities for 14 classes (pathologies)
16
  num_classes = 14
@@ -21,8 +17,8 @@ model.classifier = nn.Sequential(
21
 
22
  try:
23
  model.load_state_dict(torch.load('chexnet.pth', map_location=device))
24
- except FileNotFoundError:
25
- print("Error loading pre-trained weights: 'chexnet.pth' file not found.")
26
  model.to(device)
27
  model.eval()
28
 
@@ -72,6 +68,7 @@ def predict_disease(image):
72
  }
73
  return result
74
 
 
75
  # References to display
76
  references = """
77
  1. Huang, G., et al. (2017). Densely Connected Convolutional Networks. Proceedings of the IEEE conference on computer vision and pattern recognition.
@@ -89,6 +86,15 @@ interface = gr.Interface(
89
  description="Upload a chest X-ray to detect the probability of 14 different diseases.",
90
  )
91
 
 
 
 
 
 
 
 
 
 
92
  # Launch the Gradio app
93
  if __name__ == "__main__":
94
  interface.launch()
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms, models
 
4
  from PIL import Image
5
  import gradio as gr
6
 
7
  # Load the pre-trained DenseNet-121 model
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model = models.densenet121(pretrained=True)
 
 
 
10
 
11
  # Modify the classifier layer to output probabilities for 14 classes (pathologies)
12
  num_classes = 14
 
17
 
18
  try:
19
  model.load_state_dict(torch.load('chexnet.pth', map_location=device))
20
+ except Exception as e:
21
+ print(f"Error loading pre-trained weights: {e}")
22
  model.to(device)
23
  model.eval()
24
 
 
68
  }
69
  return result
70
 
71
+
72
  # References to display
73
  references = """
74
  1. Huang, G., et al. (2017). Densely Connected Convolutional Networks. Proceedings of the IEEE conference on computer vision and pattern recognition.
 
86
  description="Upload a chest X-ray to detect the probability of 14 different diseases.",
87
  )
88
 
89
+ # Gradio Interface
90
+ interface = gr.Interface(
91
+ fn=predict_disease,
92
+ inputs=gr.components.Image(type='pil'), # Updated input component
93
+ outputs="label", # Output is a dictionary of labels with probabilities
94
+ title="CheXNet Pneumonia Detection",
95
+ description="Upload a chest X-ray to detect the probability of 14 different diseases.",
96
+ )
97
+
98
  # Launch the Gradio app
99
  if __name__ == "__main__":
100
  interface.launch()