jainvrushab commited on
Commit
1023ae4
·
verified ·
1 Parent(s): 427f647

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -24
app.py CHANGED
@@ -1,40 +1,62 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
- from torchvision import transforms
5
 
6
  # Set the device
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
 
 
 
9
  # Load the model
10
- model = torch.load("transfer_balanced_learning_model.pth", map_location=torch.device('cpu'))
11
- model = model.to(device)
12
- model.eval()
 
 
 
 
 
 
 
13
 
14
- # Define the transforms for input images
15
- transform = transforms.Compose([
16
- transforms.Resize((224, 224)),
17
- transforms.ToTensor(),
18
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
- ])
20
 
 
21
  def predict(image):
22
- # Preprocess the image
23
- image = transform(image)
24
- image = image.unsqueeze(0).to(device)
 
 
 
 
 
 
25
  with torch.no_grad():
26
- output = model(image)
27
- _, predicted = torch.max(output, 1)
28
- class_names = ["COVID", "Lung_Opacity", "No_Tumor", "Normal", "Tumor", "Viral_Pneumonia"]
29
- return class_names[predicted.item()]
 
 
 
30
 
31
- demo = gr.Interface(
 
32
  fn=predict,
33
- inputs="image",
34
- outputs="text",
35
- title="Image Classification Model",
36
- description="Upload an image to classify it into one of the following classes: COVID, Lung_Opacity, No_Tumor, Normal, Tumor, Viral_Pneumonia"
 
 
 
37
  )
38
 
39
- # Launch the Gradio app
40
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from torchvision import models, transforms
4
  from PIL import Image
5
+
6
 
7
  # Set the device
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
+ # Define the class names
11
+ class_names = ['COVID', 'Lung_Opacity', 'No_Tumor', 'Normal', 'Tumor', 'Viral_Pneumonia']
12
+
13
  # Load the model
14
+ def load_model(model_path, num_classes):
15
+ model = models.efficientnet_b0(weights=None)
16
+ model.classifier = torch.nn.Sequential(
17
+ torch.nn.Dropout(p=0.2, inplace=True),
18
+ torch.nn.Linear(in_features=1280, out_features=num_classes, bias=True)
19
+ )
20
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
21
+ model.eval()
22
+ return model
23
+
24
 
25
+ model_path = 'transfer_balanced_learning_model.pth'
26
+ num_classes = len(class_names)
27
+ model = load_model(model_path, num_classes)
 
 
 
28
 
29
+ # Function to make predictions
30
  def predict(image):
31
+ preprocess = transforms.Compose([
32
+ transforms.Resize(256),
33
+ transforms.CenterCrop(224),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
36
+ ])
37
+ image = preprocess(image)
38
+ input_batch = image.unsqueeze(0)
39
+
40
  with torch.no_grad():
41
+ output = model(input_batch)
42
+
43
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
44
+ _, predicted_idx = torch.max(output, 1)
45
+ predicted_label = class_names[predicted_idx.item()]
46
+
47
+ return {class_names[i]: float(prob) for i, prob in enumerate(probabilities)}, predicted_label
48
 
49
+ # Create the Gradio interface
50
+ iface = gr.Interface(
51
  fn=predict,
52
+ inputs=gr.Image(type="pil"),
53
+ outputs=[
54
+ gr.Label(num_top_classes=len(class_names)),
55
+ gr.Label(label="Predicted Class")
56
+ ],
57
+ title="Medical Image Classification",
58
+ description="Upload a medical image to classify it into one of the following categories: COVID, Lung Opacity, No Tumor, Normal, Tumor, or Viral Pneumonia."
59
  )
60
 
61
+ # Launch the interface
62
+ iface.launch()