jainvrushab commited on
Commit
427f647
·
verified ·
1 Parent(s): c9ec770

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -3,8 +3,12 @@ import torch
3
  from PIL import Image
4
  from torchvision import transforms
5
 
 
 
 
6
  # Load the model
7
- model = torch.load("transfer_balanced_learning_model.pth")
 
8
  model.eval()
9
 
10
  # Define the transforms for input images
@@ -17,14 +21,13 @@ transform = transforms.Compose([
17
  def predict(image):
18
  # Preprocess the image
19
  image = transform(image)
20
- image = image.unsqueeze(0)
21
  with torch.no_grad():
22
  output = model(image)
23
  _, predicted = torch.max(output, 1)
24
  class_names = ["COVID", "Lung_Opacity", "No_Tumor", "Normal", "Tumor", "Viral_Pneumonia"]
25
  return class_names[predicted.item()]
26
 
27
-
28
  demo = gr.Interface(
29
  fn=predict,
30
  inputs="image",
@@ -34,4 +37,4 @@ demo = gr.Interface(
34
  )
35
 
36
  # Launch the Gradio app
37
- demo.launch()
 
3
  from PIL import Image
4
  from torchvision import transforms
5
 
6
+ # Set the device
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
  # Load the model
10
+ model = torch.load("transfer_balanced_learning_model.pth", map_location=torch.device('cpu'))
11
+ model = model.to(device)
12
  model.eval()
13
 
14
  # Define the transforms for input images
 
21
  def predict(image):
22
  # Preprocess the image
23
  image = transform(image)
24
+ image = image.unsqueeze(0).to(device)
25
  with torch.no_grad():
26
  output = model(image)
27
  _, predicted = torch.max(output, 1)
28
  class_names = ["COVID", "Lung_Opacity", "No_Tumor", "Normal", "Tumor", "Viral_Pneumonia"]
29
  return class_names[predicted.item()]
30
 
 
31
  demo = gr.Interface(
32
  fn=predict,
33
  inputs="image",
 
37
  )
38
 
39
  # Launch the Gradio app
40
+ demo.launch()