"""Gradio app: multi-label X-ray classification and keyword scan of PDF reports.

Tab 1 runs an uploaded X-ray through a ResNet-18 whose final layer has one
output per entry in ``LABELS`` and shows the highest-scoring labels.
Tab 2 extracts the text of an uploaded PDF and lists every label whose name
occurs in it as a whole word/phrase.
"""

import os

# SECURITY NOTE: this switches torch.load back to full pickle-based loading
# (weights_only=False). Unpickling can execute arbitrary code, so only load
# checkpoints from a trusted source. Kept ahead of the torch import, as in the
# original script.
os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'

import re
from html import escape

import fitz  # PyMuPDF
import gradio as gr
import numpy as np  # noqa: F401  (unused here; kept from the original imports)
import torch
from PIL import Image
from torchvision import transforms

# ==============================
# CONFIG
# ==============================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom labels. The classifier head is sized from len(LABELS), so this list
# must match (in length and order) the labels the checkpoint was trained on.
# NOTE(review): earlier text spoke of 23 labels but only 22 are listed here;
# confirm against the training configuration.
LABELS = [
    "Pneumonia", "Tuberculosis", "Lung Cancer", "Pulmonary Fibrosis",
    "COPD", "COVID-19 lung infection", "Pleural Effusion", "Atelectasis",
    "Cardiomegaly", "Rib Fracture", "Spinal Fracture", "Osteoporosis",
    "Arthritis", "Bone Tumor", "Scoliosis", "Dental Caries",
    "Impacted Tooth", "Jaw Fracture", "Intestinal Obstruction",
    "Kidney Stone", "Gallstone", "Foreign Object",
]

# Number of predictions shown in the X-ray tab.
TOP_K = 5

# Disease info (examples, can be extended)
DISEASE_INFO = {
    "Pneumonia": {
        "description": "Infection of the lung causing inflammation.",
        "cause": "Bacterial, viral, or fungal pathogens.",
        "recommendation": "Consult a doctor, antibiotics/antivirals if confirmed.",
    },
    "Tuberculosis": {
        "description": "Bacterial infection by Mycobacterium tuberculosis.",
        "cause": "Airborne spread from infected person.",
        "recommendation": "Seek TB specialist, long-term antibiotics.",
    },
    # ... add info for the remaining labels ...
}

# One compiled whole-phrase pattern per label, built once at import time.
# re.escape keeps punctuation in a label from being read as regex syntax.
_LABEL_PATTERNS = [
    (label, re.compile(rf"\b{re.escape(label.lower())}\b"))
    for label in LABELS
]


def get_disease_info(label: str) -> str:
    """Return an HTML snippet describing *label*.

    Labels present in ``DISEASE_INFO`` yield description, causes and
    recommendation separated by ``<br>``; any other label yields a generic
    "no extra info" sentence.
    """
    info = DISEASE_INFO.get(label)
    if info:
        return (
            f"{label}: {info['description']}<br>"
            f"Possible Causes: {info['cause']}<br>"
            f"Recommendation: {info['recommendation']}"
        )
    return f"{label}: No extra info available. Please consult a radiologist."


# ==============================
# MODEL LOADING
# ==============================
# Replace this with your own Hugging Face model ID
MODEL_ID = "your-username/your-xray-multilabel-model"

# The model should accept a [1, 3, 224, 224] tensor and output
# [1, len(LABELS)] logits; sigmoid is applied in analyse_xray.
model = torch.hub.load("pytorch/vision", "resnet18", pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, len(LABELS))
# SECURITY NOTE: pickle-based load (see the environment flag at the top of the
# file); "model_weights.pth" must come from a trusted source.
model.load_state_dict(torch.load("model_weights.pth", map_location=DEVICE))
model.to(DEVICE).eval()

# ==============================
# IMAGE PREPROCESSING
# ==============================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])


def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Apply ``transform`` to *img*, add a batch dimension, move to ``DEVICE``."""
    return transform(img).unsqueeze(0).to(DEVICE)


# ==============================
# XRAY ANALYSIS
# ==============================
def analyse_xray(img: Image.Image):
    """Classify an X-ray and return ``(html, preview_image)``.

    Returns a prompt message and ``None`` when no image was supplied, and an
    error message and ``None`` when processing fails. On success the HTML is
    a table of the top predictions (label, sigmoid score as a percentage,
    disease info) and the preview is *img* resized to 224x224.
    """
    if img is None:
        return "Please upload an X-ray image.", None
    try:
        batch = preprocess_image(img)
        with torch.no_grad():
            logits = model(batch)
        # Independent per-label scores, scaled to percentages.
        probs = torch.sigmoid(logits)[0] * 100

        # Clamp k so topk cannot fail if the label list is ever shorter.
        k = min(TOP_K, len(LABELS))
        top = torch.topk(probs, k)

        # Convert to plain Python numbers once instead of indexing and
        # formatting through 0-d tensors.
        rows = []
        for score, idx in zip(top.values.tolist(), top.indices.tolist()):
            label = LABELS[idx]
            rows.append(
                f"<tr><td>{label}</td>"
                f"<td>{score:.1f}%</td>"
                f"<td>{get_disease_info(label)}</td></tr>"
            )

        markup = (
            f"<h3>🩺 Top {k} Predictions</h3>"
            "<table>"
            "<tr><th>Condition</th><th>Confidence</th><th>Details</th></tr>"
            + "".join(rows)
            + "</table>"
        )
        return markup, img.resize((224, 224))
    except Exception as e:  # UI boundary: report the failure instead of crashing
        # The message is rendered by gr.HTML, so escape it.
        return f"Error processing image: {escape(str(e))}", None


# ==============================
# PDF REPORT ANALYSIS
# ==============================
def analyse_report(file):
    """Scan a PDF report for label names and return an HTML summary.

    *file* may be a path string or an object exposing the path as ``.name``.
    Matching is case-insensitive and whole-phrase; runs of whitespace
    (including line breaks from PDF extraction) are collapsed first so that
    multi-word labels wrapped across lines are still found.
    """
    if file is None:
        return "Please upload a PDF report."
    try:
        path = getattr(file, "name", file)
        # Context manager guarantees the document is closed even if text
        # extraction raises.
        with fitz.open(path) as doc:
            text = "\n".join(page.get_text() for page in doc)

        normalized = re.sub(r"\s+", " ", text).lower()
        found = [
            label for label, pattern in _LABEL_PATTERNS
            if pattern.search(normalized)
        ]

        if found:
            items = "".join(
                f"<li>{get_disease_info(label)}</li>" for label in found
            )
            return (
                "<h3>📃 Findings Detected in Report:</h3>"
                f"<ul>{items}</ul>"
            )
        return "<p>No known conditions detected from report text.</p>"
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"Error processing PDF: {escape(str(e))}"


# ==============================
# GRADIO UI
# ==============================
with gr.Blocks(title="🩻 Multi-Xray AI") as demo:
    gr.Markdown(
        "## 🩻 Multi-Xray AI\n"
        f"Detect and classify {len(LABELS)} different medical conditions "
        "from various X-ray types."
    )
    with gr.Tabs():
        with gr.Tab("🔍 X-ray Analysis"):
            x_input = gr.Image(label="Upload X-ray", type="pil")
            x_out_html = gr.HTML()
            x_out_image = gr.Image(label="Resized (224x224)")
            analyze_btn = gr.Button("Analyze X-ray")
            clear_btn = gr.Button("Clear")
            analyze_btn.click(
                analyse_xray,
                inputs=x_input,
                outputs=[x_out_html, x_out_image],
            )
            clear_btn.click(
                lambda: (None, "", None),
                None,
                [x_input, x_out_html, x_out_image],
            )
        with gr.Tab("📄 PDF Report Analysis"):
            pdf_input = gr.File(
                file_types=[".pdf"],
                label="Upload PDF Medical Report",
            )
            pdf_output = gr.HTML()
            analyze_pdf_btn = gr.Button("Analyze Report")
            clear_pdf_btn = gr.Button("Clear")
            analyze_pdf_btn.click(
                analyse_report,
                inputs=pdf_input,
                outputs=pdf_output,
            )
            clear_pdf_btn.click(
                lambda: (None, ""),
                None,
                [pdf_input, pdf_output],
            )

if __name__ == "__main__":
    demo.launch(server_port=int(os.getenv("PORT", 7860)), show_error=True)