File size: 4,753 Bytes
d8a6e20
836b6de
 
d8a6e20
 
 
836b6de
 
aee2703
27b2c9b
e4e9402
836b6de
032491b
 
e4e9402
836b6de
e4e9402
 
 
aee2703
 
e4e9402
836b6de
27b2c9b
836b6de
 
 
 
e4e9402
836b6de
 
a4ec3c0
836b6de
e4e9402
 
836b6de
 
27b2c9b
a4ec3c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836b6de
 
 
e4e9402
836b6de
 
 
a4ec3c0
 
 
 
 
 
836b6de
27b2c9b
aee2703
b78e348
 
 
 
 
 
 
 
fe51b53
836b6de
 
032491b
fe51b53
836b6de
2e52fb7
 
 
 
 
 
 
d8a6e20
 
aee2703
 
 
 
 
 
 
 
 
836b6de
d8a6e20
032491b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import warnings

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import gradio as gr

# Load the DenseNet-121 backbone (ImageNet-pretrained) and choose the compute device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the ``weights=``
# enum (and warns about the old form); keep the legacy call for older releases.
_densenet_weights = getattr(models, "DenseNet121_Weights", None)
if _densenet_weights is not None:
    model = models.densenet121(weights=_densenet_weights.IMAGENET1K_V1)
else:
    model = models.densenet121(pretrained=True)

# Swap the ImageNet classifier for a 14-way head, one output per pathology.
num_classes = 14
model.classifier = nn.Sequential(
    nn.Linear(model.classifier.in_features, num_classes),
    nn.Sigmoid(),  # Sigmoid, not softmax: each pathology is scored independently (multi-label)
)

# Load the fine-tuned CheXNet weights. This is best-effort so the app still
# starts without the file, but if it fails the weights may be only partially
# loaded or not at all (the 14-way head would keep its random initialisation),
# so warn loudly instead of printing a line that is easy to miss.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoint files (consider ``weights_only=True`` on torch >= 1.13).
try:
    model.load_state_dict(torch.load('chexnet.pth', map_location=device))
except Exception as e:
    warnings.warn(
        f"Error loading pre-trained weights from 'chexnet.pth': {e}. "
        "Predictions may be unreliable.",
        RuntimeWarning,
    )
model.to(device)
model.eval()  # inference mode (e.g. BatchNorm uses its running statistics)

# Preprocessing applied to every uploaded image before inference:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the standard ImageNet mean/std used by
#      torchvision's pretrained backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)

# Plain-language description of each of the 14 pathologies, keyed by class name.
# Insertion order matters: predict_disease pairs these names with the model's
# outputs positionally, so entry i describes output i.
interpretations = {
    'Atelectasis': "Partial or complete collapse of the lung.",
    'Cardiomegaly': "Enlargement of the heart.",
    'Effusion': "Fluid accumulation in the chest cavity.",
    'Infiltration': "Substances such as fluid in the lungs.",
    'Mass': "An abnormal growth in the lung.",
    'Nodule': "Small round or oval-shaped growth in the lung.",
    'Pneumonia': "Infection causing inflammation in the air sacs.",
    'Pneumothorax': "Air in the pleural space causing lung collapse.",
    'Consolidation': "Lung tissue that has filled with liquid.",
    'Edema': "Excess fluid in the lungs.",
    'Emphysema': "Damage to air sacs causing difficulty breathing.",
    'Fibrosis': "Thickening or scarring of lung tissue.",
    'Pleural Thickening': "Thickening of the pleura (lining of the lungs).",
    'Hernia': "Displacement of an organ through a structure.",
}

# Class names in model-output order, derived from the mapping above so the two
# can never drift apart.
class_names = list(interpretations)

# Prediction function
def predict_disease(image):
    """Score a chest X-ray against the 14 pathology classes.

    Args:
        image: A PIL image, as delivered by the Gradio ``Image(type='pil')``
            input. Any PIL mode is accepted; it is converted to RGB first.

    Returns:
        dict mapping ``"<class name> (<interpretation>)"`` to the model's
        sigmoid output for that class, as a Python float.

    Raises:
        ValueError: if no image was supplied (e.g. Submit pressed with an
            empty input), instead of an opaque error from the transform.
    """
    if image is None:
        raise ValueError("No image provided. Please upload a chest X-ray.")

    # X-rays are often single-channel ('L') or carry alpha ('RGBA'); the
    # 3-channel Normalize step and the DenseNet stem need exactly RGB.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        probs = model(batch).cpu().numpy().flatten()

    # Pair outputs with class names positionally (class_names is in output order).
    return {
        f"{class_name} ({interpretations[class_name]})": float(prob)
        for class_name, prob in zip(class_names, probs)
    }


# Citations for the model architecture, dataset, CheXNet paper and Gradio,
# kept one per entry and joined into a newline-delimited block (with a leading
# and trailing newline) for display in the UI.
_REFERENCE_LINES = (
    "1. Huang, G., et al. (2017). Densely Connected Convolutional Networks. Proceedings of the IEEE conference on computer vision and pattern recognition.",
    "2. Wang, X., et al. (2017). ChestX-ray8: Hospital-scale chest X-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. IEEE CVPR.",
    "3. Rajpurkar, P., et al. (2017). CheXNet: Radiologist-level pneumonia detection on chest X-rays with deep learning. arXiv preprint arXiv:1711.05225.",
    "4. Abid, A., et al. (2019). Gradio: Hassle-Free Sharing and Testing of Machine Learning Models. arXiv preprint arXiv:1906.02569.",
)
references = "\n" + "\n".join(_REFERENCE_LINES) + "\n"

# Gradio Interface.
# A single Interface is built (an earlier draft built a second one that was
# immediately overwritten and declared two outputs for a one-value function).
# The references are shown via ``article`` -- static Markdown rendered below
# the app -- rather than as an extra output component that predict_disease
# would have to return.
interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(label="Disease Probabilities"),  # dict of label -> probability
    title="CheXNet Pneumonia Detection",
    description="Upload a chest X-ray to detect the probability of 14 different diseases.",
    article=references,
)

def main():
    """Start the Gradio web server for the CheXNet ``interface``."""
    interface.launch()


# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    main()