# RadiologyScanAI / app.py
# Author: Sanjayraju30
# Commit fb7a270 (verified): "Create app.py"
import os
# Set before torch is imported below. With this flag, torch.load() calls that
# do not pass weights_only explicitly (such as the model-weights load in this
# file) use full pickle deserialisation.
# NOTE(review): unpickling can execute arbitrary code embedded in the file —
# only load checkpoints from a trusted source.
os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
import re

import fitz  # PyMuPDF
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18
# ==============================
# CONFIG
# ==============================
# Run inference on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Output classes of the classifier, in the order of the model's output columns
# (output index i is reported as LABELS[i]). The classifier head is sized with
# len(LABELS), so editing this list requires retrained weights to match.
LABELS = [
    "Pneumonia", "Tuberculosis", "Lung Cancer", "Pulmonary Fibrosis",
    "COPD", "COVID-19 lung infection", "Pleural Effusion", "Atelectasis",
    "Cardiomegaly", "Rib Fracture", "Spinal Fracture", "Osteoporosis",
    "Arthritis", "Bone Tumor", "Scoliosis", "Dental Caries",
    "Impacted Tooth", "Jaw Fracture", "Intestinal Obstruction",
    "Kidney Stone", "Gallstone", "Foreign Object",
]
# Per-label details shown next to predictions and report findings, keyed by the
# exact strings in LABELS. Each entry uses the keys "description", "cause" and
# "recommendation". Only some labels are filled in so far; the others fall back
# to a generic message in get_disease_info.
DISEASE_INFO = {
    "Pneumonia": {
        "description": "Infection of the lung causing inflammation.",
        "cause": "Bacterial, viral, or fungal pathogens.",
        "recommendation": "Consult a doctor, antibiotics/antivirals if confirmed."
    },
    "Tuberculosis": {
        "description": "Bacterial infection by Mycobacterium tuberculosis.",
        "cause": "Airborne spread from infected person.",
        "recommendation": "Seek TB specialist, long-term antibiotics."
    },
    # ... add entries for the remaining labels in LABELS ...
}


def get_disease_info(label):
    """Return an HTML snippet describing *label* for the results views.

    Args:
        label: A condition name, normally an entry of ``LABELS``.

    Returns:
        ``"<b>label</b>: description<br>...Possible Causes...<br>...Recommendation..."``
        when ``DISEASE_INFO`` has a non-empty entry for *label*. Otherwise, a
        single line advising the user to consult a radiologist. When an entry
        is only partially filled in, each missing field is shown with a
        placeholder instead of raising ``KeyError``.
    """
    info = DISEASE_INFO.get(label)
    if not info:
        return f"<b>{label}</b>: No extra info available. Please consult a radiologist."
    description = info.get("description", "No description available.")
    cause = info.get("cause", "Not specified.")
    recommendation = info.get("recommendation", "Please consult a radiologist.")
    return (
        f"<b>{label}</b>: {description}<br>"
        f"<b>Possible Causes:</b> {cause}<br>"
        f"<b>Recommendation:</b> {recommendation}"
    )
# ==============================
# MODEL LOADING
# ==============================
# Placeholder Hugging Face model ID.
# NOTE(review): nothing in this section reads it — the weights below come from
# the local file model_weights.pth.
MODEL_ID = "your-username/your-xray-multilabel-model"
# The model takes a [1, 3, 224, 224] tensor and returns one raw score (logit)
# per entry of LABELS, shape [1, len(LABELS)]. The head below is a plain Linear
# layer, so no sigmoid is applied here; callers convert logits to probabilities.
#
# The ResNet-18 architecture comes from the installed torchvision rather than
# torch.hub.load("pytorch/vision", ...). That call downloads the torchvision
# repository from GitHub at startup, which fails offline, and the downloaded
# code may not match the installed torch. Called with no arguments,
# resnet18() is randomly initialised (the equivalent of pretrained=False);
# the trained weights are loaded next.
model = resnet18()
# Replace the default classification head with one output per label.
model.fc = torch.nn.Linear(model.fc.in_features, len(LABELS))
# map_location lets weights saved on a GPU load on a CPU-only host.
# NOTE(review): TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD is set at the top of the file,
# so this call uses full pickle deserialisation; only ship a trusted
# model_weights.pth.
model.load_state_dict(torch.load("model_weights.pth", map_location=DEVICE))
# Inference only: move to the target device and put batch-norm layers in
# evaluation mode.
model.to(DEVICE).eval()
# ==============================
# IMAGE PREPROCESSING
# ==============================
# Resize to the 224x224 input the model expects, replicate the single grey
# channel into the 3 channels the model takes, then map pixel values from
# ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)


def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a single-image batch on ``DEVICE``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` ready for ``model``.
    """
    chw = transform(img)
    batch = chw[None, ...]  # add the leading batch dimension
    return batch.to(DEVICE)
# ==============================
# XRAY ANALYSIS
# ==============================
def analyse_xray(img: Image.Image):
    """Classify an uploaded X-ray and render the top predictions as HTML.

    Each output column gets an independent sigmoid probability, so the model
    is treated as multi-label. Probabilities are shown as percentages. The
    five highest-scoring labels are listed, or all of them if ``LABELS`` has
    fewer than five entries.

    Args:
        img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(html, preview)`` pair for the HTML and image outputs. On success:
        an HTML table of predictions and the input resized to 224x224. When
        there is no input or an error occurs: a message string and ``None``.
    """
    if img is None:
        return "Please upload an X-ray image.", None
    from html import escape

    try:
        batch = preprocess_image(img)
        with torch.no_grad():
            logits = model(batch)
        # Bring the scores to the CPU once. Formatting individual elements of
        # a CUDA tensor would force a device sync for every row.
        probs = (torch.sigmoid(logits)[0] * 100).cpu()
        k = min(5, probs.numel())
        top = torch.topk(probs, k)
        rows = "".join(
            f"<tr><td>{LABELS[i]}</td>"
            f"<td>{score:.1f}%</td>"
            f"<td>{get_disease_info(LABELS[i])}</td></tr>"
            for score, i in zip(top.values.tolist(), top.indices.tolist())
        )
        table = (
            f"<h3>🩺 Top {k} Predictions</h3><table border='1'>"
            "<tr><th>Condition</th><th>Confidence</th><th>Details</th></tr>"
            f"{rows}</table>"
        )
        return table, img.resize((224, 224))
    except Exception as e:
        # Show the failure in the UI. Escape it, because the message is
        # rendered by an HTML component and may contain characters like '<'.
        return f"Error processing image: {escape(str(e))}", None
# ==============================
# PDF REPORT ANALYSIS
# ==============================
def analyse_report(file):
    """Scan a PDF report's text for mentions of the known condition names.

    Args:
        file: The value delivered by the ``gr.File`` component, or ``None``
            when nothing was uploaded. Depending on the Gradio version, the
            upload arrives either as a path string or as a tempfile-like
            object whose ``.name`` is the path. Both forms are accepted.

    Returns:
        An HTML fragment. When any label from ``LABELS`` occurs in the report
        text as a whole word (case-insensitive), it lists each match with its
        ``get_disease_info`` details. Otherwise it is a "no conditions"
        paragraph. A plain prompt or error string is returned instead when
        no file was uploaded or processing fails.
    """
    if file is None:
        return "Please upload a PDF report."
    from html import escape

    try:
        path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
        # The context manager closes the document even if text extraction raises.
        with fitz.open(path) as doc:
            text = "\n".join(page.get_text() for page in doc)
        lowered = text.lower()  # lower-case once rather than once per label
        # re.escape makes each label match literally, even if a future label
        # contains regex metacharacters such as '(', '+' or '.'.
        found = [
            label
            for label in LABELS
            if re.search(rf"\b{re.escape(label.lower())}\b", lowered)
        ]
        if not found:
            return "<p>No known conditions detected from report text.</p>"
        items = "".join(f"<li>{get_disease_info(label)}</li>" for label in found)
        return f"<h3>📃 Findings Detected in Report:</h3><ul>{items}</ul>"
    except Exception as e:
        # Escape the message, because it is rendered by an HTML component and
        # may echo user-controlled text such as the uploaded file name.
        return f"Error processing PDF: {escape(str(e))}"
# ==============================
# GRADIO UI
# ==============================
# The app has two tabs: X-ray classification (analyse_xray) and a keyword scan
# of PDF reports (analyse_report). Each tab has an Analyze button wired to its
# handler and a Clear button that resets that tab's input and outputs.
with gr.Blocks(title="🩻 Multi-Xray AI") as demo:
    # Derive the condition count from LABELS so the header always matches the
    # model's actual output classes.
    gr.Markdown(
        "## 🩻 Multi-Xray AI\n"
        f"Detect and classify {len(LABELS)} different medical conditions "
        "from various X-ray types."
    )
    with gr.Tabs():
        with gr.Tab("🔍 X-ray Analysis"):
            x_input = gr.Image(label="Upload X-ray", type="pil")
            x_out_html = gr.HTML()
            x_out_image = gr.Image(label="Resized (224x224)")
            analyze_btn = gr.Button("Analyze X-ray")
            clear_btn = gr.Button("Clear")
            analyze_btn.click(analyse_xray, inputs=x_input, outputs=[x_out_html, x_out_image])
            clear_btn.click(lambda: (None, "", None), None, [x_input, x_out_html, x_out_image])
        with gr.Tab("📄 PDF Report Analysis"):
            pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF Medical Report")
            pdf_output = gr.HTML()
            analyze_pdf_btn = gr.Button("Analyze Report")
            clear_pdf_btn = gr.Button("Clear")
            analyze_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_output)
            clear_pdf_btn.click(lambda: (None, ""), None, [pdf_input, pdf_output])
if __name__ == "__main__":
    # Use the PORT environment variable when the host provides one; otherwise
    # fall back to 7860, the conventional Gradio port.
    port = int(os.getenv("PORT", 7860))
    demo.launch(server_port=port, show_error=True)