DHEIVER commited on
Commit
82af901
·
1 Parent(s): 6ae70cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image
 
4
 
5
  # Load the model
6
- model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet101', pretrained=False)
7
  model.load_state_dict(torch.load('resnet101_pneumonia.pt', map_location=torch.device('cpu')))
8
  model.eval()
9
 
@@ -14,8 +15,13 @@ class_names = ["normal", "pneumonia"]
14
  def predict(img):
15
  # Preprocess the image
16
  img = img.convert("RGB")
17
- transform = torch.hub.load('pytorch/vision:v0.9.0', 'transforms', pretrained=True)
18
- img = transform(img)
 
 
 
 
 
19
  img = img.unsqueeze(0)
20
 
21
  # Make prediction
@@ -35,7 +41,7 @@ def predict(img):
35
  # Create the Gradio interface
36
  iface = gr.Interface(
37
  fn=predict,
38
- inputs=gr.inputs.Image(type="pil", label="Input Image"),
39
  outputs=gr.outputs.Label(num_top_classes=2, label="Predicted Class"),
40
  title="PneumoniaDetector 👁",
41
  description="A ResNet101 computer vision model to detect pneumonia",
 
1
  import gradio as gr
2
  import torch
3
+ from torchvision import transforms
4
+ import numpy as np
5
 
6
  # Load the model
7
+ model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet101', pretrained=True)
8
  model.load_state_dict(torch.load('resnet101_pneumonia.pt', map_location=torch.device('cpu')))
9
  model.eval()
10
 
 
15
  def predict(img):
16
  # Preprocess the image
17
  img = img.convert("RGB")
18
+ preprocess = transforms.Compose([
19
+ transforms.Resize(256),
20
+ transforms.CenterCrop(224),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
23
+ ])
24
+ img = preprocess(img)
25
  img = img.unsqueeze(0)
26
 
27
  # Make prediction
 
41
  # Create the Gradio interface
42
  iface = gr.Interface(
43
  fn=predict,
44
+ inputs=gr.inputs.Image(type="numpy", label="Input Image"),
45
  outputs=gr.outputs.Label(num_top_classes=2, label="Predicted Class"),
46
  title="PneumoniaDetector 👁",
47
  description="A ResNet101 computer vision model to detect pneumonia",