Spaces:
Sleeping
Sleeping
File size: 2,024 Bytes
d8a6e20 836b6de d8a6e20 836b6de 27b2c9b e4e9402 836b6de 032491b e4e9402 836b6de e4e9402 836b6de 27b2c9b 836b6de e4e9402 836b6de e4e9402 836b6de 27b2c9b 836b6de e4e9402 836b6de 27b2c9b 836b6de 032491b 836b6de e4e9402 d8a6e20 836b6de d8a6e20 032491b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import warnings

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
# Build the classifier once at import time: a DenseNet-121 backbone whose
# ImageNet head is replaced by a 14-way multi-label head, then (best effort)
# overwritten with the fine-tuned weights in 'chexnet.pth'.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `pretrained=True` starts the backbone from torchvision's ImageNet weights;
# they are replaced wholesale if 'chexnet.pth' loads below.
# NOTE(review): newer torchvision deprecates `pretrained=` in favour of
# `weights=`; kept because the installed torchvision version is not visible.
model = models.densenet121(pretrained=True)
# One output per pathology. Sigmoid (not softmax) makes every output an
# independent probability, since one image can show several findings.
num_classes = 14
model.classifier = nn.Sequential(
    nn.Linear(model.classifier.in_features, num_classes),
    nn.Sigmoid(),  # Use Sigmoid for multi-label classification
)


def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored in a loaded checkpoint object.

    A plain state dict is returned unchanged (same object). Two common
    wrappers are unwrapped as well:

    * a training checkpoint that nests the weights under ``"state_dict"``;
    * weights saved from ``nn.DataParallel``, where every key carries a
      ``"module."`` prefix.

    Anything that is not a dict is returned as-is so that
    ``load_state_dict`` reports the problem itself.
    """
    if not isinstance(checkpoint, dict):
        return checkpoint
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    prefix = "module."
    # Only strip when *every* key is prefixed, so a normal state dict is
    # never rewritten.
    if checkpoint and all(key.startswith(prefix) for key in checkpoint):
        checkpoint = {
            key[len(prefix):]: value for key, value in checkpoint.items()
        }
    return checkpoint


try:
    # NOTE(review): torch.load unpickles the file -- only ship trusted
    # checkpoints here.
    model.load_state_dict(
        _extract_state_dict(torch.load("chexnet.pth", map_location=device))
    )
except Exception as exc:
    # Best effort: keep the app up, but make it unmistakable that the new
    # classifier head is still randomly initialised, so every score the UI
    # shows is meaningless until the checkpoint loads.
    warnings.warn(
        f"Error loading pre-trained weights: {exc}. The classifier head is "
        "randomly initialised; predictions are NOT meaningful.",
        RuntimeWarning,
    )
model.to(device)
# Inference mode: fixes BatchNorm statistics and disables dropout.
model.eval()
# Preprocessing pipeline applied to each uploaded image before inference:
#   1. resize to a fixed 224x224 input,
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. standardise each of the three channels with the ImageNet mean/std.
# Normalize supplies exactly three statistics, so it expects a 3-channel image.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Display names of the 14 ChestX-ray14 findings. Order matters: predict_disease
# pairs entry i with the i-th value of the model output.
class_names = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural Thickening",
    "Hernia",
]
# Prediction function
def predict_disease(image):
image = transform(image).unsqueeze(0).to(device) # Transform and add batch dimension
with torch.no_grad():
outputs = model(image)
outputs = outputs.cpu().numpy().flatten()
result = {class_name: float(prob) for class_name, prob in zip(class_names, outputs)}
return result
# Gradio UI: a single PIL image input and a Label output that renders the
# {class name: score} dict returned by predict_disease.
interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.components.Image(type="pil"),
    outputs="label",
    title="CheXNet Pneumonia Detection",
    description=(
        "Upload a chest X-ray to detect the probability of 14 different "
        "diseases."
    ),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    interface.launch()
|