File size: 2,024 Bytes
d8a6e20
836b6de
 
d8a6e20
 
 
836b6de
 
 
27b2c9b
e4e9402
836b6de
032491b
 
e4e9402
836b6de
e4e9402
 
 
 
 
 
836b6de
27b2c9b
836b6de
 
 
 
e4e9402
836b6de
 
 
 
e4e9402
 
836b6de
 
27b2c9b
836b6de
 
 
e4e9402
836b6de
 
 
 
 
27b2c9b
836b6de
 
 
032491b
836b6de
 
e4e9402
d8a6e20
 
836b6de
d8a6e20
032491b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import warnings

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import gradio as gr

# Load the pre-trained DenseNet-121 model
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): newer torchvision releases deprecate ``pretrained=True`` in favour
# of ``weights=...``; it is kept here for compatibility with older releases.
model = models.densenet121(pretrained=True)

# Modify the classifier layer to output probabilities for 14 classes (pathologies).
# The original ImageNet head is replaced, so this new Linear layer starts out
# randomly initialised until the fine-tuned checkpoint below is applied.
num_classes = 14
model.classifier = nn.Sequential(
    nn.Linear(model.classifier.in_features, num_classes),
    nn.Sigmoid(),  # Use Sigmoid for multi-label classification
)

# True only once the fine-tuned checkpoint has been applied successfully.
# When it stays False, the classifier head above is untrained and its outputs
# carry no diagnostic meaning.
weights_loaded = False
try:
    # NOTE: torch.load unpickles the file, so only point it at a trusted checkpoint.
    model.load_state_dict(torch.load('chexnet.pth', map_location=device))
except Exception as e:
    # Top-level boundary: keep the app usable, but make the failure impossible
    # to miss instead of a quiet print to stdout.
    warnings.warn(
        f"Error loading pre-trained weights from 'chexnet.pth': {e}. "
        "The 14-class classifier head is randomly initialised, so "
        "predictions are NOT meaningful.",
        RuntimeWarning,
    )
else:
    weights_loaded = True
model.to(device)
model.eval()  # inference mode (affects layers such as BatchNorm/Dropout)

# Define image transformations (resize, normalize)
# Per-channel mean/std values used by torchvision's ImageNet-pretrained models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Force a fixed 224x224 input (aspect ratio is not preserved).
        transforms.Resize((224, 224)),
        # PIL image -> float tensor (C, H, W) scaled to [0, 1].
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Class names for the 14 diseases (labels from ChestX-ray14 dataset).
# Order matters: entry i is paired with output unit i of the classifier head.
class_names = [
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural Thickening',
    'Hernia',
]

# Prediction function
def predict_disease(image):
    """Return per-pathology scores for a chest X-ray image.

    Args:
        image: A PIL image, as delivered by the ``Image(type='pil')`` input.
            It may be ``None`` when no image has been provided.

    Returns:
        A dict mapping each name in ``class_names`` to the model's output for
        that pathology, as a Python float. It is empty when ``image`` is ``None``.
    """
    if image is None:
        # Nothing uploaded: return an empty mapping instead of crashing inside transform.
        return {}
    # Normalize uses 3-channel mean/std, so a grayscale ("L") or RGBA image
    # would fail there. Convert to 3-channel RGB first.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)  # Transform and add batch dimension

    with torch.no_grad():
        scores = model(batch).cpu().flatten().tolist()
    return {class_name: float(score) for class_name, score in zip(class_names, scores)}

# Gradio Interface
# The uploaded X-ray is handed to predict_disease as a PIL image, and the
# returned {class name: probability} dict is shown by a "label" output.
interface = gr.Interface(
    title="CheXNet Pneumonia Detection",
    description="Upload a chest X-ray to detect the probability of 14 different diseases.",
    fn=predict_disease,
    inputs=gr.components.Image(type='pil'),
    outputs="label",
)

# Start the web app only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()