Nateware commited on
Commit
8009765
Β·
1 Parent(s): ff22a06

Add application file

Browse files
Files changed (1) hide show
  1. app.py +241 -0
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import time
4
+ import gradio as gr
5
+ import requests
6
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ import logging
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Global variables for model
16
+ model = None
17
+ processor = None
18
+ device = None
19
+
20
+ def load_model():
21
+ """Load the AI model once at startup"""
22
+ global model, processor, device
23
+
24
+ logger.info("Loading AI model...")
25
+
26
+ # Get Hugging Face token from environment
27
+ hf_token = os.environ.get('HF_TOKEN')
28
+
29
+ model_id = "mychen76/paligemma-3b-mix-448-med_30k-ct-brain"
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
32
+
33
+ logger.info(f"Using device: {device}")
34
+
35
+ try:
36
+ # Load processor and model with authentication token
37
+ processor = AutoProcessor.from_pretrained(
38
+ model_id,
39
+ token=hf_token
40
+ )
41
+ model = PaliGemmaForConditionalGeneration.from_pretrained(
42
+ model_id,
43
+ torch_dtype=dtype,
44
+ token=hf_token
45
+ ).to(device).eval()
46
+
47
+ logger.info("Model loaded successfully!")
48
+ return True
49
+
50
+ except Exception as e:
51
+ logger.error(f"Error loading model: {e}")
52
+ return False
53
+
54
+ def analyze_brain_scan(image, patient_name="", patient_age="", symptoms=""):
55
+ """Analyze brain scan image and return medical findings"""
56
+ try:
57
+ if image is None:
58
+ return "Please upload a brain scan image."
59
+
60
+ # Convert to PIL Image if needed
61
+ if not isinstance(image, Image.Image):
62
+ image = Image.fromarray(image).convert("RGB")
63
+
64
+ # Run AI inference
65
+ prompt = "<image> Findings:"
66
+ inputs = processor(
67
+ images=image,
68
+ text=prompt,
69
+ return_tensors="pt"
70
+ ).to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
71
+
72
+ with torch.no_grad():
73
+ generated_ids = model.generate(**inputs, max_new_tokens=100)
74
+
75
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
76
+
77
+ # Clean up the result
78
+ if result.startswith(prompt):
79
+ result = result[len(prompt):].strip()
80
+
81
+ # Format the response
82
+ timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
83
+
84
+ formatted_result = f"""
85
+ ## Brain CT Analysis Results
86
+
87
+ **Patient Information:**
88
+ - Name: {patient_name or 'Not provided'}
89
+ - Age: {patient_age or 'Not provided'}
90
+ - Symptoms: {symptoms or 'Not provided'}
91
+ - Analysis Date: {timestamp}
92
+
93
+ **AI Findings:**
94
+ {result}
95
+
96
+ **Note:** This is an AI-generated analysis for educational purposes only.
97
+ Always consult with qualified medical professionals for actual diagnosis.
98
+ """
99
+
100
+ return formatted_result
101
+
102
+ except Exception as e:
103
+ logger.error(f"Analysis error: {e}")
104
+ return f"Error during analysis: {str(e)}"
105
+
106
+ def create_api_response(image, patient_name="", patient_age="", symptoms=""):
107
+ """Create API-compatible response for integration"""
108
+ try:
109
+ if image is None:
110
+ return {"error": "No image provided"}
111
+
112
+ # Convert to PIL Image if needed
113
+ if not isinstance(image, Image.Image):
114
+ image = Image.fromarray(image).convert("RGB")
115
+
116
+ # Run AI inference
117
+ prompt = "<image> Findings:"
118
+ inputs = processor(
119
+ images=image,
120
+ text=prompt,
121
+ return_tensors="pt"
122
+ ).to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
123
+
124
+ with torch.no_grad():
125
+ generated_ids = model.generate(**inputs, max_new_tokens=100)
126
+
127
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
128
+
129
+ # Clean up the result
130
+ if result.startswith(prompt):
131
+ result = result[len(prompt):].strip()
132
+
133
+ # Create API response
134
+ response = {
135
+ "prediction": result,
136
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
137
+ "patient_info": {
138
+ "name": patient_name,
139
+ "age": patient_age,
140
+ "symptoms": symptoms
141
+ },
142
+ "model_info": {
143
+ "model_id": "mychen76/paligemma-3b-mix-448-med_30k-ct-brain",
144
+ "device": str(device)
145
+ }
146
+ }
147
+
148
+ return response
149
+
150
+ except Exception as e:
151
+ logger.error(f"API error: {e}")
152
+ return {"error": f"Analysis failed: {str(e)}"}
153
+
154
+ # Load model at startup
155
+ logger.info("Initializing Brain CT Analyzer...")
156
+ if load_model():
157
+ logger.info("Model loaded successfully!")
158
+ else:
159
+ logger.error("Failed to load model!")
160
+
161
+ # Create Gradio interface
162
+ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
163
+ gr.Markdown("""
164
+ # 🧠 Brain CT Analyzer
165
+
166
+ Upload a brain CT scan image for AI-powered analysis. This tool uses the PaliGemma medical model
167
+ to provide preliminary findings.
168
+
169
+ **⚠️ Important:** This is for educational/research purposes only. Always consult qualified medical professionals.
170
+ """)
171
+
172
+ with gr.Row():
173
+ with gr.Column(scale=1):
174
+ image_input = gr.Image(
175
+ label="Upload Brain CT Scan",
176
+ type="pil",
177
+ height=400
178
+ )
179
+
180
+ with gr.Group():
181
+ patient_name = gr.Textbox(
182
+ label="Patient Name (Optional)",
183
+ placeholder="Enter patient name"
184
+ )
185
+ patient_age = gr.Textbox(
186
+ label="Patient Age (Optional)",
187
+ placeholder="Enter patient age"
188
+ )
189
+ symptoms = gr.Textbox(
190
+ label="Symptoms (Optional)",
191
+ placeholder="Describe symptoms",
192
+ lines=3
193
+ )
194
+
195
+ analyze_btn = gr.Button(
196
+ "πŸ” Analyze Brain Scan",
197
+ variant="primary",
198
+ size="lg"
199
+ )
200
+
201
+ with gr.Column(scale=1):
202
+ result_output = gr.Markdown(
203
+ label="Analysis Results",
204
+ value="Upload an image and click 'Analyze Brain Scan' to see results."
205
+ )
206
+
207
+ # API endpoint simulation
208
+ with gr.Accordion("πŸ”Œ API Response (for developers)", open=False):
209
+ api_output = gr.JSON(label="API Response Format")
210
+
211
+ # Event handlers
212
+ analyze_btn.click(
213
+ fn=analyze_brain_scan,
214
+ inputs=[image_input, patient_name, patient_age, symptoms],
215
+ outputs=result_output
216
+ )
217
+
218
+ analyze_btn.click(
219
+ fn=create_api_response,
220
+ inputs=[image_input, patient_name, patient_age, symptoms],
221
+ outputs=api_output
222
+ )
223
+
224
+ # Example images (if you have any)
225
+ gr.Markdown("""
226
+ ## πŸ“‹ Usage Instructions:
227
+ 1. Upload a brain CT scan image (JPEG or PNG)
228
+ 2. Optionally fill in patient information
229
+ 3. Click "Analyze Brain Scan" to get AI findings
230
+ 4. Review the results in the output panel
231
+
232
+ ## πŸ”— Integration:
233
+ This interface can be integrated with your medical app using the Gradio API.
234
+ """)
235
+
236
+ if __name__ == "__main__":
237
+ demo.launch(
238
+ server_name="0.0.0.0",
239
+ server_port=7860,
240
+ share=True
241
+ )