File size: 2,101 Bytes
0431e63
 
1023ae4
0431e63
1023ae4
0431e63
427f647
 
 
1023ae4
 
 
0431e63
1023ae4
 
 
 
 
 
 
 
 
 
0431e63
1023ae4
 
 
0431e63
1023ae4
0431e63
1023ae4
 
 
 
 
 
 
 
 
0431e63
1023ae4
 
 
 
 
 
 
0431e63
1023ae4
 
0431e63
1023ae4
 
 
 
 
 
 
0431e63
 
1023ae4
f447738
d2c4479
 
 
 
c883643
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image


# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the class names
class_names = ['COVID', 'Lung_Opacity', 'No_Tumor', 'Normal', 'Tumor', 'Viral_Pneumonia']

# Load the model
def load_model(model_path, num_classes):
    model = models.efficientnet_b0(weights=None)
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(in_features=1280, out_features=num_classes, bias=True)
    )
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    return model


model_path = 'transfer_balanced_learning_model.pth'
num_classes = len(class_names)
model = load_model(model_path, num_classes)

# Function to make predictions
def predict(image):
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = preprocess(image)
    input_batch = image.unsqueeze(0)

    with torch.no_grad():
        output = model(input_batch)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    _, predicted_idx = torch.max(output, 1)
    predicted_label = class_names[predicted_idx.item()]

    return {class_names[i]: float(prob) for i, prob in enumerate(probabilities)}, predicted_label

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=len(class_names)),
        gr.Label(label="Predicted Class")
    ],
    title="Medical Image Classification",
    description="Upload a medical image to classify it into one of the following categories: COVID, Lung Opacity, No Tumor, Normal, Tumor, or Viral Pneumonia."
)

# Launch the interface
#iface.launch(share=True)
#iface.launch()

if __name__ == "__main__":
    print("Launching Gradio Demo...")
#    iface.queue()
    iface.launch()