Spaces:
Build error
Build error
File size: 5,772 Bytes
fb7a270 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
import re

import fitz # PyMuPDF
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import models
from torchvision import transforms
# ==============================
# CONFIG
# ==============================
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Condition labels. The position of a label in this list is the index of the
# matching model output, so the order must not change independently of the
# trained weights.
# NOTE(review): this list has 22 entries, not the 23 an earlier comment
# claimed -- confirm whether a label is missing.
LABELS = [
    # Chest / lungs
    "Pneumonia",
    "Tuberculosis",
    "Lung Cancer",
    "Pulmonary Fibrosis",
    "COPD",
    "COVID-19 lung infection",
    "Pleural Effusion",
    "Atelectasis",
    "Cardiomegaly",
    # Bones / spine
    "Rib Fracture",
    "Spinal Fracture",
    "Osteoporosis",
    "Arthritis",
    "Bone Tumor",
    "Scoliosis",
    # Dental / jaw
    "Dental Caries",
    "Impacted Tooth",
    "Jaw Fracture",
    # Abdomen / other
    "Intestinal Obstruction",
    "Kidney Stone",
    "Gallstone",
    "Foreign Object",
]
# Per-condition explanatory text, keyed by label. Each value carries a
# "description", a "cause" and a "recommendation" string. Labels without an
# entry here fall back to a generic message in get_disease_info().
# TODO: add a row for every remaining label.
DISEASE_INFO = {
    name: {
        "description": description,
        "cause": cause,
        "recommendation": recommendation,
    }
    for name, description, cause, recommendation in (
        (
            "Pneumonia",
            "Infection of the lung causing inflammation.",
            "Bacterial, viral, or fungal pathogens.",
            "Consult a doctor, antibiotics/antivirals if confirmed.",
        ),
        (
            "Tuberculosis",
            "Bacterial infection by Mycobacterium tuberculosis.",
            "Airborne spread from infected person.",
            "Seek TB specialist, long-term antibiotics.",
        ),
    )
}
def get_disease_info(label):
    """Return an HTML snippet describing ``label``.

    When ``DISEASE_INFO`` has a (non-empty) entry for the label, the snippet
    lists its description, possible causes and recommendation separated by
    ``<br>``. Otherwise a generic "no extra info" sentence is returned.
    """
    info = DISEASE_INFO.get(label)
    if not info:
        return f"<b>{label}</b>: No extra info available. Please consult a radiologist."

    parts = [
        f"<b>{label}</b>: {info['description']}",
        f"<b>Possible Causes:</b> {info['cause']}",
        f"<b>Recommendation:</b> {info['recommendation']}",
    ]
    return "<br>".join(parts)
# ==============================
# MODEL LOADING
# ==============================
# Replace this with your own Hugging Face model ID
MODEL_ID = "your-username/your-xray-multilabel-model"

# Build a ResNet-18 and swap its final layer for one that emits a raw score
# (logit) per entry in LABELS; analyse_xray() applies the sigmoid afterwards.
# The model should accept [1, 3, 224, 224] tensor and output [1, len(LABELS)]
#
# The architecture is taken from the installed torchvision package instead of
# torch.hub.load("pytorch/vision", ...): that call fetches repository code
# over the network at start-up and relies on the deprecated `pretrained=`
# argument. No pretrained weights are requested either way; the parameters
# come from the local checkpoint below.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, len(LABELS))

# NOTE(review): the file sets TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1, which (per
# the PyTorch docs) lets this torch.load unpickle arbitrary objects. Only load
# checkpoints from a trusted source. If "model_weights.pth" is a plain
# state_dict, passing weights_only=True here would be safer -- confirm.
model.load_state_dict(torch.load("model_weights.pth", map_location=DEVICE))

# Inference only: move to the target device and switch to eval mode.
model.to(DEVICE).eval()
# ==============================
# IMAGE PREPROCESSING
# ==============================
# Preprocessing applied to every uploaded image, in this order:
#   1. resize to the 224x224 input size the model is fed,
#   2. convert to grayscale replicated over 3 channels,
#   3. convert to a tensor,
#   4. normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-item batch tensor on ``DEVICE``.

    The image goes through the module-level ``transform`` pipeline, gains a
    leading batch dimension, and is moved to the configured device.
    """
    tensor = transform(img)
    batch = tensor.unsqueeze(0)
    return batch.to(DEVICE)
# ==============================
# XRAY ANALYSIS
# ==============================
def analyse_xray(img: Image.Image):
    """Classify an uploaded X-ray and render the top predictions as HTML.

    Args:
        img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(html, preview)`` tuple. ``html`` is a results table, or a plain
        message when there is no input or processing failed. ``preview`` is
        the input resized to 224x224, or ``None`` when there is nothing to
        show.
    """
    # Function-scope import: the name ``html`` is a common local variable in
    # this file, so only the one helper needed is brought in.
    from html import escape

    if img is None:
        return "Please upload an X-ray image.", None

    try:
        batch = preprocess_image(img)
        with torch.no_grad():
            logits = model(batch)

        # Each label gets its own sigmoid score, expressed as a percentage.
        probs = torch.sigmoid(logits)[0] * 100

        # Never request more entries than the model produced.
        k = min(5, probs.numel())
        top = torch.topk(probs, k)

        # Convert to plain Python numbers once, so list indexing and number
        # formatting do not depend on 0-d tensor behaviour.
        rows = []
        for idx, score in zip(top.indices.tolist(), top.values.tolist()):
            label = LABELS[idx]
            rows.append(
                f"<tr><td>{label}</td>"
                f"<td>{score:.1f}%</td>"
                f"<td>{get_disease_info(label)}</td></tr>"
            )

        header = (
            f"<h3>🩺 Top {k} Predictions</h3>"
            "<table border='1'><tr><th>Condition</th><th>Confidence</th><th>Details</th></tr>"
        )
        markup = header + "".join(rows) + "</table>"
        return markup, img.resize((224, 224))
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app. The
        # message is escaped because it is rendered as HTML and exception
        # text may contain angle brackets.
        return f"Error processing image: {escape(str(e))}", None
# ==============================
# PDF REPORT ANALYSIS
# ==============================
def analyse_report(file):
    """Scan an uploaded PDF report for mentions of the known condition labels.

    This is a case-insensitive whole-phrase keyword search over the text
    extracted from the PDF. It does not interpret context, so a phrase such
    as "no evidence of pneumonia" still counts as a mention.

    Args:
        file: The uploaded file (an object exposing the path as ``.name``, or
            a plain path string), or ``None`` when nothing was uploaded.

    Returns:
        An HTML string: a list of the matched conditions with their details,
        a "nothing detected" paragraph, or a message when there is no input
        or processing failed.
    """
    # Function-scope import so only the one helper needed is brought in.
    from html import escape

    if file is None:
        return "Please upload a PDF report."

    try:
        # Accept both a temp-file style object and a bare path string.
        path = getattr(file, "name", file)

        # The context manager closes the document even if extraction raises.
        with fitz.open(path) as doc:
            text = "\n".join(page.get_text() for page in doc)

        found = []
        for label in LABELS:
            # Escape each word so label text is never read as regex syntax,
            # and allow any whitespace run between words: extracted PDF text
            # often breaks a phrase across lines.
            words = (re.escape(word) for word in label.split())
            pattern = r"\b" + r"\s+".join(words) + r"\b"
            if re.search(pattern, text, flags=re.IGNORECASE):
                found.append(label)

        if not found:
            return "<p>No known conditions detected from report text.</p>"

        # NOTE(review): the leading "π" in the heading looks like a
        # mis-decoded emoji; the original glyph could not be recovered.
        items = "".join(f"<li>{get_disease_info(label)}</li>" for label in found)
        return "<h3>π Findings Detected in Report:</h3><ul>" + items + "</ul>"
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app. The
        # message is escaped because it is rendered as HTML.
        return f"Error processing PDF: {escape(str(e))}"
# ==============================
# GRADIO UI
# ==============================
# Two-tab interface: one tab classifies an uploaded X-ray image, the other
# scans an uploaded PDF report for known condition names.
# NOTE(review): the lone "π" prefixes on the tab labels look like mis-decoded
# emoji; the original glyphs could not be recovered, so they are left as-is.
with gr.Blocks(title="🩻 Multi-Xray AI") as demo:
    gr.Markdown(
        "## 🩻 Multi-Xray AI\n"
        # Derive the count from LABELS so the blurb cannot drift from the
        # actual number of conditions the model scores.
        f"Detect and classify {len(LABELS)} different medical conditions from various X-ray types."
    )
    with gr.Tabs():
        with gr.Tab("π X-ray Analysis"):
            x_input = gr.Image(label="Upload X-ray", type="pil")
            x_out_html = gr.HTML()
            x_out_image = gr.Image(label="Resized (224x224)")
            analyze_btn = gr.Button("Analyze X-ray")
            clear_btn = gr.Button("Clear")
            analyze_btn.click(
                analyse_xray,
                inputs=x_input,
                outputs=[x_out_html, x_out_image],
            )
            # Reset the upload, the result table and the preview together.
            clear_btn.click(
                lambda: (None, "", None),
                None,
                [x_input, x_out_html, x_out_image],
            )
        with gr.Tab("π PDF Report Analysis"):
            pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF Medical Report")
            pdf_output = gr.HTML()
            analyze_pdf_btn = gr.Button("Analyze Report")
            clear_pdf_btn = gr.Button("Clear")
            analyze_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_output)
            # Reset the upload and the findings together.
            clear_pdf_btn.click(lambda: (None, ""), None, [pdf_input, pdf_output])
if __name__ == "__main__":
    # Use the PORT environment variable when set; otherwise port 7860.
    port = int(os.getenv("PORT", 7860))
    demo.launch(server_port=port, show_error=True)
|