rtik007 commited on
Commit
fe51b53
·
verified ·
1 Parent(s): b78e348

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -1,12 +1,16 @@
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms, models
 
4
  from PIL import Image
5
  import gradio as gr
6
 
7
  # Load the pre-trained DenseNet-121 model
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
- model = models.densenet121(pretrained=True)
 
 
 
10
 
11
  # Modify the classifier layer to output probabilities for 14 classes (pathologies)
12
  num_classes = 14
@@ -17,8 +21,8 @@ model.classifier = nn.Sequential(
17
 
18
  try:
19
  model.load_state_dict(torch.load('chexnet.pth', map_location=device))
20
- except Exception as e:
21
- print(f"Error loading pre-trained weights: {e}")
22
  model.to(device)
23
  model.eval()
24
 
@@ -76,14 +80,13 @@ references = """
76
  4. Abid, A., et al. (2019). Gradio: Hassle-Free Sharing and Testing of Machine Learning Models. arXiv preprint arXiv:1906.02569.
77
  """
78
 
79
- # Gradio Interface
80
  interface = gr.Interface(
81
  fn=predict_disease,
82
  inputs=gr.components.Image(type='pil'), # Updated input component
83
- outputs="label", # Output is a dictionary of labels with probabilities
84
  title="CheXNet Pneumonia Detection",
85
  description="Upload a chest X-ray to detect the probability of 14 different diseases.",
86
- additional_outputs=[gr.components.Textbox(label="References", value=references, lines=10)], # Display references
87
  )
88
 
89
  # Launch the Gradio app
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms, models
4
+ from torchvision.models import DenseNet121_Weights
5
  from PIL import Image
6
  import gradio as gr
7
 
8
  # Load the pre-trained DenseNet-121 model
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # Use DenseNet121 with updated weights parameter
12
+ weights = DenseNet121_Weights.IMAGENET1K_V1
13
+ model = models.densenet121(weights=weights)
14
 
15
  # Modify the classifier layer to output probabilities for 14 classes (pathologies)
16
  num_classes = 14
 
21
 
22
  try:
23
  model.load_state_dict(torch.load('chexnet.pth', map_location=device))
24
+ except FileNotFoundError:
25
+ print("Error loading pre-trained weights: 'chexnet.pth' file not found.")
26
  model.to(device)
27
  model.eval()
28
 
 
80
  4. Abid, A., et al. (2019). Gradio: Hassle-Free Sharing and Testing of Machine Learning Models. arXiv preprint arXiv:1906.02569.
81
  """
82
 
83
+ # Gradio Interface without using deprecated parameters
84
  interface = gr.Interface(
85
  fn=predict_disease,
86
  inputs=gr.components.Image(type='pil'), # Updated input component
87
+ outputs=[gr.components.Label(label="Disease Probabilities"), gr.components.Textbox(label="References", value=references, lines=10)],
88
  title="CheXNet Pneumonia Detection",
89
  description="Upload a chest X-ray to detect the probability of 14 different diseases.",
 
90
  )
91
 
92
  # Launch the Gradio app