File size: 5,772 Bytes
fb7a270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import re

os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'

import fitz  # PyMuPDF
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import models
from torchvision import transforms

# ==============================
# CONFIG
# ==============================
# Run inference on CUDA when a GPU is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Condition labels, in the order of the classifier's output units: output
# index i is reported as LABELS[i]. The model's final layer is sized from
# len(LABELS), so this list must match the label set the weights were
# trained with (same length AND same order).
# NOTE(review): the list holds 22 entries; if the model was trained on 23
# conditions, one is missing here and loading the weights will fail with a
# shape mismatch — confirm against the training label set.
LABELS = [
    "Pneumonia", "Tuberculosis", "Lung Cancer", "Pulmonary Fibrosis",
    "COPD", "COVID-19 lung infection", "Pleural Effusion", "Atelectasis",
    "Cardiomegaly", "Rib Fracture", "Spinal Fracture", "Osteoporosis",
    "Arthritis", "Bone Tumor", "Scoliosis", "Dental Caries",
    "Impacted Tooth", "Jaw Fracture", "Intestinal Obstruction",
    "Kidney Stone", "Gallstone", "Foreign Object"
]

# Explanatory text per condition, keyed by label name. Each entry carries a
# "description", a "cause" and a "recommendation" string. Only a couple of
# labels are filled in so far; labels without an entry fall back to a generic
# message in get_disease_info.
DISEASE_INFO = {
    "Pneumonia": dict(
        description="Infection of the lung causing inflammation.",
        cause="Bacterial, viral, or fungal pathogens.",
        recommendation="Consult a doctor, antibiotics/antivirals if confirmed.",
    ),
    "Tuberculosis": dict(
        description="Bacterial infection by Mycobacterium tuberculosis.",
        cause="Airborne spread from infected person.",
        recommendation="Seek TB specialist, long-term antibiotics.",
    ),
    # TODO: add entries for the remaining labels.
}


def get_disease_info(label):
    """Return an HTML snippet describing ``label``.

    When ``DISEASE_INFO`` has a (non-empty) entry for the label, the snippet
    shows its description, possible causes and recommendation separated by
    ``<br>`` tags. Otherwise a generic "consult a radiologist" line is returned.
    """
    info = DISEASE_INFO.get(label)
    if not info:
        return f"<b>{label}</b>: No extra info available. Please consult a radiologist."
    sections = [
        f"<b>{label}</b>: {info['description']}",
        f"<b>Possible Causes:</b> {info['cause']}",
        f"<b>Recommendation:</b> {info['recommendation']}",
    ]
    return "<br>".join(sections)

# ==============================
# MODEL LOADING
# ==============================
# Hugging Face model ID placeholder.
# NOTE(review): nothing in this file reads MODEL_ID — the weights are loaded
# from the local file "model_weights.pth" below. Wire it up or remove it.
MODEL_ID = "your-username/your-xray-multilabel-model"

# ResNet-18 backbone with its final fully connected layer replaced by a
# len(LABELS)-way head. The model takes a [1, 3, 224, 224] tensor and returns
# [1, len(LABELS)] raw scores; analyse_xray applies a sigmoid to them, so each
# label is scored independently (multi-label).
# The architecture is built from the locally installed torchvision rather than
# torch.hub.load("pytorch/vision", ...), which must fetch the hub repo over the
# network at startup and may not match the installed torchvision version.
# resnet18() with no arguments gives randomly initialised weights (nothing is
# downloaded) on both the legacy `pretrained=` and the newer `weights=` API.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, len(LABELS))

# NOTE(review): TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 is set at the top of the file,
# so this torch.load (which does not pass weights_only) performs full
# unpickling. Only load checkpoints from a trusted source; for a plain state
# dict, weights_only=True would be the safer choice.
model.load_state_dict(torch.load("model_weights.pth", map_location=DEVICE))
model.to(DEVICE).eval()

# ==============================
# IMAGE PREPROCESSING
# ==============================
transform = transforms.Compose(
    [
        # Fixed network input size.
        transforms.Resize((224, 224)),
        # Collapse to a single luminance channel, then replicate it three
        # times because the ResNet stem takes 3-channel input.
        transforms.Grayscale(num_output_channels=3),
        # PIL image -> float tensor (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: maps [0, 1] onto [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)


def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a model-ready batch of one.

    Applies ``transform``, prepends a batch dimension so the result is
    ``[1, 3, 224, 224]``, and moves it to ``DEVICE``.
    """
    batch = transform(img)[None, ...]
    return batch.to(DEVICE)

# ==============================
# XRAY ANALYSIS
# ==============================
def analyse_xray(img: Image.Image):
    """Classify an uploaded X-ray and render the highest-scoring labels as HTML.

    Args:
        img: Image from the Gradio upload widget, or None when nothing was
            uploaded.

    Returns:
        A ``(html, preview)`` tuple. On success ``html`` is a table of the top
        (at most 5) labels with their sigmoid scores as percentages and the
        matching disease info, and ``preview`` is the input resized to 224x224.
        On missing input or any processing error, ``html`` is a plain message
        and ``preview`` is None.
    """
    if img is None:
        return "Please upload an X-ray image.", None
    try:
        x = preprocess_image(img)
        with torch.no_grad():
            logits = model(x)
            # Independent per-label probabilities (multi-label), in percent.
            probs = torch.sigmoid(logits)[0] * 100

        # Clamp k so a head with fewer than 5 outputs does not make topk raise.
        k = min(5, probs.numel())
        top = torch.topk(probs, k)
        # Move the results to the host once, instead of indexing a (possibly
        # CUDA) tensor for every table row while formatting.
        scores = top.values.cpu().tolist()
        indices = top.indices.cpu().tolist()

        rows = [
            f"<tr><td>{LABELS[i]}</td>"
            f"<td>{score:.1f}%</td>"
            f"<td>{get_disease_info(LABELS[i])}</td></tr>"
            for i, score in zip(indices, scores)
        ]
        html = (
            f"<h3>🩺 Top {k} Predictions</h3><table border='1'>"
            "<tr><th>Condition</th><th>Confidence</th><th>Details</th></tr>"
            + "".join(rows)
            + "</table>"
        )
        return html, img.resize((224, 224))
    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing the handler.
        return f"Error processing image: {str(e)}", None

# ==============================
# PDF REPORT ANALYSIS
# ==============================
def analyse_report(file):
    """Scan the text of an uploaded PDF report for known condition names.

    Each entry of ``LABELS`` is searched for as a case-insensitive whole-word
    phrase. This is a plain keyword match with no negation handling, so
    "no evidence of pneumonia" still reports Pneumonia.

    Args:
        file: The Gradio upload, or None. This is either a filesystem path
            (str / os.PathLike, as ``gr.File`` returns in Gradio 4+) or a
            file-like wrapper whose ``.name`` holds the path (older Gradio).

    Returns:
        An HTML list of the matched conditions with their disease info, a
        "no known conditions" paragraph, or a plain prompt/error message.
    """
    if file is None:
        return "Please upload a PDF report."
    try:
        # Path objects also have a ``.name`` (just the file name), so only use
        # ``.name`` for upload wrappers that are not paths themselves.
        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            path = file.name

        doc = fitz.open(path)
        try:
            text = "\n".join(page.get_text() for page in doc)
        finally:
            # Release the document even if text extraction fails.
            doc.close()

        found = []
        for label in LABELS:
            # Escape regex metacharacters in each word, and allow any run of
            # whitespace between words: PDF extraction often wraps a
            # multi-word term such as "Lung\nCancer" across lines.
            words = (re.escape(word) for word in label.split())
            pattern = r"\b" + r"\s+".join(words) + r"\b"
            if re.search(pattern, text, flags=re.IGNORECASE):
                found.append(label)

        if not found:
            return "<p>No known conditions detected from report text.</p>"
        items = "".join(f"<li>{get_disease_info(label)}</li>" for label in found)
        return f"<h3>📃 Findings Detected in Report:</h3><ul>{items}</ul>"
    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing the handler.
        return f"Error processing PDF: {str(e)}"

# ==============================
# GRADIO UI
# ==============================
# Two tabs: X-ray classification and a keyword scan of PDF reports. In each tab
# the Analyze button runs the handler and the Clear button resets that tab's
# input and outputs.
with gr.Blocks(title="🩻 Multi-Xray AI") as demo:
    # The condition count is derived from LABELS so the description always
    # matches the number of labels the classifier actually scores.
    gr.Markdown(
        "## 🩻 Multi-Xray AI\n"
        f"Detect and classify {len(LABELS)} different medical conditions from various X-ray types."
    )
    with gr.Tabs():
        with gr.Tab("🔍 X-ray Analysis"):
            x_input = gr.Image(label="Upload X-ray", type="pil")
            x_out_html = gr.HTML()
            x_out_image = gr.Image(label="Resized (224x224)")
            analyze_btn = gr.Button("Analyze X-ray")
            clear_btn = gr.Button("Clear")
            analyze_btn.click(analyse_xray, inputs=x_input, outputs=[x_out_html, x_out_image])
            # Reset the upload, the result HTML and the preview image.
            clear_btn.click(lambda: (None, "", None), None, [x_input, x_out_html, x_out_image])

        with gr.Tab("📄 PDF Report Analysis"):
            pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF Medical Report")
            pdf_output = gr.HTML()
            analyze_pdf_btn = gr.Button("Analyze Report")
            clear_pdf_btn = gr.Button("Clear")
            analyze_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_output)
            # Reset the uploaded file and the findings HTML.
            clear_pdf_btn.click(lambda: (None, ""), None, [pdf_input, pdf_output])

if __name__ == "__main__":
    # Listen on $PORT when the environment provides it, falling back to 7860.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_port=port, show_error=True)