Sanjayraju30 commited on
Commit
fb7a270
Β·
verified Β·
1 Parent(s): 7d8697e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
3
+
4
+ import numpy as np
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ import re
9
+ import fitz # PyMuPDF
10
+ from torchvision import transforms
11
+
12
+ # ==============================
13
+ # CONFIG
14
+ # ==============================
15
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ # Your 23 custom labels
18
+ LABELS = [
19
+ "Pneumonia", "Tuberculosis", "Lung Cancer", "Pulmonary Fibrosis",
20
+ "COPD", "COVID-19 lung infection", "Pleural Effusion", "Atelectasis",
21
+ "Cardiomegaly", "Rib Fracture", "Spinal Fracture", "Osteoporosis",
22
+ "Arthritis", "Bone Tumor", "Scoliosis", "Dental Caries",
23
+ "Impacted Tooth", "Jaw Fracture", "Intestinal Obstruction",
24
+ "Kidney Stone", "Gallstone", "Foreign Object"
25
+ ]
26
+
27
+ # Disease info (examples, can be extended)
28
+ DISEASE_INFO = {
29
+ "Pneumonia": {
30
+ "description": "Infection of the lung causing inflammation.",
31
+ "cause": "Bacterial, viral, or fungal pathogens.",
32
+ "recommendation": "Consult a doctor, antibiotics/antivirals if confirmed."
33
+ },
34
+ "Tuberculosis": {
35
+ "description": "Bacterial infection by Mycobacterium tuberculosis.",
36
+ "cause": "Airborne spread from infected person.",
37
+ "recommendation": "Seek TB specialist, long-term antibiotics."
38
+ },
39
+ # ... add info for all 23 labels ...
40
+ }
41
+
42
+ def get_disease_info(label):
43
+ d = DISEASE_INFO.get(label)
44
+ if d:
45
+ return (
46
+ f"<b>{label}</b>: {d['description']}<br>"
47
+ f"<b>Possible Causes:</b> {d['cause']}<br>"
48
+ f"<b>Recommendation:</b> {d['recommendation']}"
49
+ )
50
+ return f"<b>{label}</b>: No extra info available. Please consult a radiologist."
51
+
52
+ # ==============================
53
+ # MODEL LOADING
54
+ # ==============================
55
+ # Replace this with your own Hugging Face model ID
56
+ MODEL_ID = "your-username/your-xray-multilabel-model"
57
+
58
+ # Example: assume model is a torch.nn.Module with sigmoid output
59
+ # The model should accept [1, 3, 224, 224] tensor and output [1, len(LABELS)]
60
+ model = torch.hub.load("pytorch/vision", "resnet18", pretrained=False)
61
+ model.fc = torch.nn.Linear(model.fc.in_features, len(LABELS))
62
+ model.load_state_dict(torch.load("model_weights.pth", map_location=DEVICE))
63
+ model.to(DEVICE).eval()
64
+
65
+ # ==============================
66
+ # IMAGE PREPROCESSING
67
+ # ==============================
68
+ transform = transforms.Compose([
69
+ transforms.Resize((224, 224)),
70
+ transforms.Grayscale(num_output_channels=3),
71
+ transforms.ToTensor(),
72
+ transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
73
+ ])
74
+
75
+ def preprocess_image(img: Image.Image) -> torch.Tensor:
76
+ return transform(img).unsqueeze(0).to(DEVICE)
77
+
78
+ # ==============================
79
+ # XRAY ANALYSIS
80
+ # ==============================
81
+ def analyse_xray(img: Image.Image):
82
+ if img is None:
83
+ return "Please upload an X-ray image.", None
84
+ try:
85
+ x = preprocess_image(img)
86
+ with torch.no_grad():
87
+ outputs = model(x)
88
+ probs = torch.sigmoid(outputs)[0] * 100
89
+ topk = torch.topk(probs, 5)
90
+
91
+ html = "<h3>🩺 Top 5 Predictions</h3><table border='1'><tr><th>Condition</th><th>Confidence</th><th>Details</th></tr>"
92
+ for idx in topk.indices:
93
+ label = LABELS[idx]
94
+ html += (
95
+ f"<tr><td>{label}</td>"
96
+ f"<td>{probs[idx]:.1f}%</td>"
97
+ f"<td>{get_disease_info(label)}</td></tr>"
98
+ )
99
+ html += "</table>"
100
+ return html, img.resize((224, 224))
101
+ except Exception as e:
102
+ return f"Error processing image: {str(e)}", None
103
+
104
+ # ==============================
105
+ # PDF REPORT ANALYSIS
106
+ # ==============================
107
+ def analyse_report(file):
108
+ if file is None:
109
+ return "Please upload a PDF report."
110
+ try:
111
+ doc = fitz.open(file.name)
112
+ text = "\n".join(page.get_text() for page in doc)
113
+ doc.close()
114
+ found = []
115
+ for label in LABELS:
116
+ if re.search(rf"\b{label.lower()}\b", text.lower()):
117
+ found.append(label)
118
+ if found:
119
+ html = "<h3>πŸ“ƒ Findings Detected in Report:</h3><ul>"
120
+ for label in found:
121
+ html += f"<li>{get_disease_info(label)}</li>"
122
+ html += "</ul>"
123
+ else:
124
+ html = "<p>No known conditions detected from report text.</p>"
125
+ return html
126
+ except Exception as e:
127
+ return f"Error processing PDF: {str(e)}"
128
+
129
+ # ==============================
130
+ # GRADIO UI
131
+ # ==============================
132
+ with gr.Blocks(title="🩻 Multi-Xray AI") as demo:
133
+ gr.Markdown(
134
+ "## 🩻 Multi-Xray AI\n"
135
+ "Detect and classify 23 different medical conditions from various X-ray types."
136
+ )
137
+ with gr.Tabs():
138
+ with gr.Tab("πŸ” X-ray Analysis"):
139
+ x_input = gr.Image(label="Upload X-ray", type="pil")
140
+ x_out_html = gr.HTML()
141
+ x_out_image = gr.Image(label="Resized (224x224)")
142
+ analyze_btn = gr.Button("Analyze X-ray")
143
+ clear_btn = gr.Button("Clear")
144
+ analyze_btn.click(analyse_xray, inputs=x_input, outputs=[x_out_html, x_out_image])
145
+ clear_btn.click(lambda: (None, "", None), None, [x_input, x_out_html, x_out_image])
146
+
147
+ with gr.Tab("πŸ“„ PDF Report Analysis"):
148
+ pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF Medical Report")
149
+ pdf_output = gr.HTML()
150
+ analyze_pdf_btn = gr.Button("Analyze Report")
151
+ clear_pdf_btn = gr.Button("Clear")
152
+ analyze_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_output)
153
+ clear_pdf_btn.click(lambda: (None, ""), None, [pdf_input, pdf_output])
154
+
155
+ if __name__ == "__main__":
156
+ demo.launch(server_port=int(os.getenv("PORT", 7860)), show_error=True)