Spaces:
Build error
Build error
import os | |
os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1' | |
import numpy as np | |
import torch | |
import gradio as gr | |
from PIL import Image | |
import re | |
import fitz # PyMuPDF | |
from torchvision import transforms | |
# ==============================
# CONFIG
# ==============================
# Use the first visible CUDA GPU when available, otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Condition names in the order of the classifier's output units:
# output index i is the score for LABELS[i].
# NOTE(review): this list holds 22 entries, although the UI text and other
# comments speak of 23 conditions. Confirm whether a label is missing. The
# model's final layer is sized from len(LABELS), so this list must match the
# trained checkpoint's output size exactly.
LABELS = [
    "Pneumonia",
    "Tuberculosis",
    "Lung Cancer",
    "Pulmonary Fibrosis",
    "COPD",
    "COVID-19 lung infection",
    "Pleural Effusion",
    "Atelectasis",
    "Cardiomegaly",
    "Rib Fracture",
    "Spinal Fracture",
    "Osteoporosis",
    "Arthritis",
    "Bone Tumor",
    "Scoliosis",
    "Dental Caries",
    "Impacted Tooth",
    "Jaw Fracture",
    "Intestinal Obstruction",
    "Kidney Stone",
    "Gallstone",
    "Foreign Object",
]
# Optional per-condition details shown next to predictions and report findings.
# Every entry must provide "description", "cause" and "recommendation" keys.
# Conditions without an entry are rendered with a generic fallback message.
# TODO: add entries for the remaining names in LABELS.
DISEASE_INFO = {
    "Pneumonia": dict(
        description="Infection of the lung causing inflammation.",
        cause="Bacterial, viral, or fungal pathogens.",
        recommendation="Consult a doctor, antibiotics/antivirals if confirmed.",
    ),
    "Tuberculosis": dict(
        description="Bacterial infection by Mycobacterium tuberculosis.",
        cause="Airborne spread from infected person.",
        recommendation="Seek TB specialist, long-term antibiotics.",
    ),
}
def get_disease_info(label):
    """Return an HTML snippet describing *label* for the results views.

    When ``DISEASE_INFO`` has a non-empty entry for *label*, the snippet lists
    its description, possible causes and recommendation separated by ``<br>``.
    Otherwise a generic "consult a radiologist" note is returned.
    """
    info = DISEASE_INFO.get(label)
    if not info:
        return f"<b>{label}</b>: No extra info available. Please consult a radiologist."
    sections = [
        f"<b>{label}</b>: {info['description']}",
        f"<b>Possible Causes:</b> {info['cause']}",
        f"<b>Recommendation:</b> {info['recommendation']}",
    ]
    return "<br>".join(sections)
# ==============================
# MODEL LOADING
# ==============================
# The network is built from the locally installed torchvision rather than via
# torch.hub.load("pytorch/vision", ...). The hub route fetches code from
# GitHub at startup, which needs network access, and its `pretrained=` flag
# is deprecated in favour of `weights=`. Calling the constructor with neither
# flag gives a randomly initialised network in every torchvision version, and
# `num_classes` sizes the final linear layer directly.
from torchvision import models as tv_models

# NOTE(review): MODEL_ID is not referenced anywhere in the visible code; the
# weights are read from MODEL_WEIGHTS_PATH below. Either download the
# checkpoint from this Hub repo or remove the constant.
MODEL_ID = "your-username/your-xray-multilabel-model"
# Local checkpoint, expected to be a state_dict for the ResNet-18 built below.
# It can be overridden with the MODEL_WEIGHTS_PATH environment variable.
MODEL_WEIGHTS_PATH = os.getenv("MODEL_WEIGHTS_PATH", "model_weights.pth")

if not os.path.isfile(MODEL_WEIGHTS_PATH):
    # Fail fast with an actionable message instead of a bare torch.load error.
    raise FileNotFoundError(
        f"Model weights not found at {MODEL_WEIGHTS_PATH!r}. Add the trained "
        "checkpoint to the app or set the MODEL_WEIGHTS_PATH environment variable."
    )

# ResNet-18 with a len(LABELS)-way linear head. It takes a [1, 3, 224, 224]
# batch and returns raw logits of shape [1, len(LABELS)]; the sigmoid is
# applied later, in analyse_xray.
model = tv_models.resnet18(num_classes=len(LABELS))
# NOTE: the TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD flag set at the top of the file
# disables torch.load's weights-only (safe unpickling) mode, so only load
# checkpoints you trust. load_state_dict raises if the checkpoint's head size
# differs from len(LABELS).
model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE))
model.to(DEVICE).eval()
# ==============================
# IMAGE PREPROCESSING
# ==============================
# The pipeline runs these steps in order:
#   1. Resize to the model's fixed 224x224 input.
#   2. Replicate the luminance channel into 3 identical channels, because
#      ResNet's first convolution expects 3-channel input.
#   3. Convert to a float CHW tensor in [0, 1].
#   4. Map each channel to roughly [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): the normalisation must match the one used at training time.
# Confirm this against the training code.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)
def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a model-ready batch of one on ``DEVICE``.

    Applies the module-level ``transform`` and prepends a batch axis. With the
    pipeline configured above, the result has shape (1, 3, 224, 224).
    """
    single = transform(img)
    batched = single.unsqueeze(0)
    return batched.to(DEVICE)
# ==============================
# XRAY ANALYSIS
# ==============================
def analyse_xray(img: Image.Image):
    """Classify an uploaded X-ray and render the most likely conditions as HTML.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.

    Returns:
        A ``(html, preview)`` tuple. ``html`` is either a table of the top
        predictions (sigmoid confidence in percent, plus disease info) or a
        message. ``preview`` is the input resized to 224x224 for display, or
        ``None`` when there is no input or an error occurred.
    """
    from html import escape

    if img is None:
        return "Please upload an X-ray image.", None
    try:
        batch = preprocess_image(img)
        with torch.no_grad():
            logits = model(batch)
        # Multi-label head: an independent sigmoid per class, as a percentage.
        # Move to the CPU once so the rows below work on plain Python numbers.
        probs = (torch.sigmoid(logits)[0] * 100).cpu()
        # Never ask topk for more entries than there are classes.
        k = min(5, probs.numel())
        top = torch.topk(probs, k)

        rows = [
            f"<tr><td>{LABELS[idx]}</td>"
            f"<td>{score:.1f}%</td>"
            f"<td>{get_disease_info(LABELS[idx])}</td></tr>"
            for score, idx in zip(top.values.tolist(), top.indices.tolist())
        ]
        table = (
            f"<h3>🩺 Top {k} Predictions</h3>"
            "<table border='1'><tr><th>Condition</th><th>Confidence</th>"
            "<th>Details</th></tr>"
            + "".join(rows)
            + "</table>"
        )
        return table, img.resize((224, 224))
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the app.
        # The message is escaped because it is rendered by gr.HTML.
        return f"Error processing image: {escape(str(e))}", None
# ==============================
# PDF REPORT ANALYSIS
# ==============================
def _find_labels_in_text(text, labels):
    """Return the entries of *labels* mentioned in *text*, in *labels* order.

    Matching rules:
      * Matching is case-insensitive and on whole words only.
      * Any run of whitespace, including line breaks from PDF extraction, is
        accepted between the words of a multi-word label.
      * A simple plural suffix ("s"/"es") is accepted.
      * Regex metacharacters in labels are escaped.

    This is plain keyword spotting: negated mentions such as "no pleural
    effusion" are still reported.
    """
    found = []
    for label in labels:
        pattern = r"\s+".join(re.escape(word) for word in label.split())
        if re.search(rf"\b{pattern}(?:e?s)?\b", text, flags=re.IGNORECASE):
            found.append(label)
    return found


def analyse_report(file):
    """Scan a PDF report's text for known condition names and render them as HTML.

    Args:
        file: Value from the ``gr.File`` input. It may be an object with a
            ``.name`` path attribute, a plain path string, or ``None``.

    Returns:
        One of the following HTML strings:
          * a list of the conditions found, each with its disease info;
          * a "no known conditions" paragraph;
          * an upload prompt;
          * an error message.
    """
    from html import escape

    if file is None:
        return "Please upload a PDF report."
    try:
        # Accept both a tempfile-like object and a bare path string.
        path = getattr(file, "name", file)
        # The context manager closes the document even if text extraction fails.
        with fitz.open(path) as doc:
            text = "\n".join(page.get_text() for page in doc)
        found = _find_labels_in_text(text, LABELS)
        if not found:
            return "<p>No known conditions detected from report text.</p>"
        items = "".join(f"<li>{get_disease_info(label)}</li>" for label in found)
        # NOTE(review): the leading "π" looks like a mis-encoded emoji. The
        # original glyph cannot be recovered from this file, so replace it.
        return f"<h3>π Findings Detected in Report:</h3><ul>{items}</ul>"
    except Exception as e:
        # UI boundary: report the failure instead of crashing. The message is
        # escaped because it is rendered by gr.HTML.
        return f"Error processing PDF: {escape(str(e))}"
# ==============================
# GRADIO UI
# ==============================
# The app has two tabs:
#   * image classification, wired to analyse_xray;
#   * keyword scan of PDF reports, wired to analyse_report.
# Each tab's "Clear" button resets that tab's input and outputs.
# NOTE(review): the lone "π" at the start of the tab titles looks like a
# mis-encoded emoji. The original glyph cannot be recovered here, so replace it.
with gr.Blocks(title="🩻 Multi-Xray AI") as demo:
    # Derive the condition count from LABELS so the text cannot drift from the
    # model head. A hard-coded "23" previously disagreed with the labels list.
    gr.Markdown(
        "## 🩻 Multi-Xray AI\n"
        f"Detect and classify {len(LABELS)} different medical conditions "
        "from various X-ray types."
    )
    with gr.Tabs():
        with gr.Tab("π X-ray Analysis"):
            x_input = gr.Image(label="Upload X-ray", type="pil")
            x_out_html = gr.HTML()
            x_out_image = gr.Image(label="Resized (224x224)")
            analyze_btn = gr.Button("Analyze X-ray")
            clear_btn = gr.Button("Clear")
            analyze_btn.click(analyse_xray, inputs=x_input, outputs=[x_out_html, x_out_image])
            clear_btn.click(lambda: (None, "", None), None, [x_input, x_out_html, x_out_image])
        with gr.Tab("π PDF Report Analysis"):
            pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF Medical Report")
            pdf_output = gr.HTML()
            analyze_pdf_btn = gr.Button("Analyze Report")
            clear_pdf_btn = gr.Button("Clear")
            analyze_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_output)
            clear_pdf_btn.click(lambda: (None, ""), None, [pdf_input, pdf_output])
if __name__ == "__main__":
    # Hosting platforms may dictate the port through $PORT; otherwise use
    # 7860, Gradio's customary default. show_error surfaces exceptions in the UI.
    listen_port = int(os.getenv("PORT", 7860))
    demo.launch(show_error=True, server_port=listen_port)