Spaces:
Sleeping
Sleeping
File size: 4,753 Bytes
d8a6e20 836b6de d8a6e20 836b6de aee2703 27b2c9b e4e9402 836b6de 032491b e4e9402 836b6de e4e9402 aee2703 e4e9402 836b6de 27b2c9b 836b6de e4e9402 836b6de a4ec3c0 836b6de e4e9402 836b6de 27b2c9b a4ec3c0 836b6de e4e9402 836b6de a4ec3c0 836b6de 27b2c9b aee2703 b78e348 fe51b53 836b6de 032491b fe51b53 836b6de 2e52fb7 d8a6e20 aee2703 836b6de d8a6e20 032491b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import gradio as gr
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Start from torchvision's DenseNet-121 with its bundled pre-trained weights.
model = models.densenet121(pretrained=True)

# Swap the stock classifier for a 14-way head (one output per pathology).
# The trailing sigmoid scores every class independently, which is what a
# multi-label problem needs.
num_classes = 14
model.classifier = nn.Sequential(
    nn.Linear(in_features=model.classifier.in_features, out_features=num_classes),
    nn.Sigmoid(),
)

# Best-effort load of the fine-tuned checkpoint: any failure is only printed
# and the model keeps whatever weights it already has.
# NOTE(review): after a failed load the replacement classifier head is still
# untrained, so the reported scores would not be meaningful — confirm that
# continuing to serve in that state is intended.
try:
    model.load_state_dict(torch.load('chexnet.pth', map_location=device))
except Exception as exc:
    print(f"Error loading pre-trained weights: {exc}")

# Place the network on the chosen device and switch it to evaluation mode.
model.to(device).eval()
# Preprocessing pipeline applied to every input image before inference:
#   1. resize to 224x224 pixels,
#   2. convert the PIL image to a tensor,
#   3. normalize the three channels with the mean/std listed below
#      (these look like the usual ImageNet statistics — confirm they match
#      the preprocessing the checkpoint was trained with).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# The 14 pathologies, each paired with a one-line plain-language description.
# Order matters: position i in this table labels output i of the model, so
# entries must not be reordered.
_PATHOLOGIES = (
    ('Atelectasis', "Partial or complete collapse of the lung."),
    ('Cardiomegaly', "Enlargement of the heart."),
    ('Effusion', "Fluid accumulation in the chest cavity."),
    ('Infiltration', "Substances such as fluid in the lungs."),
    ('Mass', "An abnormal growth in the lung."),
    ('Nodule', "Small round or oval-shaped growth in the lung."),
    ('Pneumonia', "Infection causing inflammation in the air sacs."),
    ('Pneumothorax', "Air in the pleural space causing lung collapse."),
    ('Consolidation', "Lung tissue that has filled with liquid."),
    ('Edema', "Excess fluid in the lungs."),
    ('Emphysema', "Damage to air sacs causing difficulty breathing."),
    ('Fibrosis', "Thickening or scarring of lung tissue."),
    ('Pleural Thickening', "Thickening of the pleura (lining of the lungs)."),
    ('Hernia', "Displacement of an organ through a structure."),
)

# Ordered list of class labels, aligned with the model's output vector.
class_names = [label for label, _ in _PATHOLOGIES]

# Lookup from class label to its description.
interpretations = dict(_PATHOLOGIES)
# Prediction function
def predict_disease(image):
    """Score one chest X-ray against the 14 pathology classes.

    Args:
        image: PIL image handed over by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        dict: Maps ``"<class name> (<interpretation>)"`` to the model's
        sigmoid score for that class as a Python float. An empty dict is
        returned when ``image`` is ``None``.
    """
    # Submitting without an image yields None; return an empty result
    # instead of crashing inside the transform pipeline.
    if image is None:
        return {}

    # The normalization step carries three channel statistics, so a
    # grayscale or RGBA image must be converted to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess, add the batch dimension and move to the model's device.
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        scores = model(batch)

    # Drop the batch dimension and convert to plain Python floats.
    scores = scores.flatten().cpu().tolist()

    # Result with interpretations
    return {
        f"{name} ({interpretations[name]})": float(score)
        for name, score in zip(class_names, scores)
    }
# Bibliography for the network architecture, the dataset, the CheXNet paper
# and the UI toolkit. Stored as one newline-delimited string that both starts
# and ends with a newline.
references = (
    "\n"
    "1. Huang, G., et al. (2017). Densely Connected Convolutional Networks. Proceedings of the IEEE conference on computer vision and pattern recognition.\n"
    "2. Wang, X., et al. (2017). ChestX-ray8: Hospital-scale chest X-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. IEEE CVPR.\n"
    "3. Rajpurkar, P., et al. (2017). CheXNet: Radiologist-level pneumonia detection on chest X-rays with deep learning. arXiv preprint arXiv:1711.05225.\n"
    "4. Abid, A., et al. (2019). Gradio: Hassle-Free Sharing and Testing of Machine Learning Models. arXiv preprint arXiv:1906.02569.\n"
)
# Gradio Interface
# Defined exactly once. predict_disease returns a single dict of
# label -> probability, so there is exactly one output component (a Label);
# declaring a second output would not match the function's return value.
# The bibliography is shown through `article`, which Gradio renders below the
# input/output components, rather than through an output Textbox.
interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.components.Image(type='pil'),
    outputs=gr.components.Label(label="Disease Probabilities"),
    title="CheXNet Pneumonia Detection",
    description="Upload a chest X-ray to detect the probability of 14 different diseases.",
    article=references,
)

# Launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()
|