Nateware committed
Commit 78c8cdc · 1 Parent(s): 5ac0314

should work

Files changed (1)
app.py +88 -64
app.py CHANGED
@@ -16,44 +16,67 @@ logger = logging.getLogger(__name__)
 model = None
 processor = None
 device = None
+model_loaded = False
 
 def load_model():
-    """Load the AI model once at startup"""
-    global model, processor, device
+    """Load the AI model exactly like in Colab"""
+    global model, processor, device, model_loaded
 
-    logger.info("Loading AI model...")
-
-    # Get Hugging Face token from environment
-    hf_token = os.environ.get('HF_TOKEN')
+    logger.info("Loading AI model (Colab style)...")
 
+    # === Load AI Model === (exactly like Colab)
     model_id = "mychen76/paligemma-3b-mix-448-med_30k-ct-brain"
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
     logger.info(f"Using device: {device}")
+    logger.info(f"Using dtype: {dtype}")
+    logger.info(f"CUDA available: {torch.cuda.is_available()}")
 
     try:
-        # Load processor and model with authentication token
-        processor = AutoProcessor.from_pretrained(
-            model_id,
-            token=hf_token
-        )
+        # Load exactly like Colab (no token, no trust_remote_code)
+        logger.info("Loading processor...")
+        processor = AutoProcessor.from_pretrained(model_id)
+
+        logger.info("Loading model...")
         model = PaliGemmaForConditionalGeneration.from_pretrained(
             model_id,
-            torch_dtype=dtype,
-            token=hf_token
+            torch_dtype=dtype
         ).to(device).eval()
 
         logger.info("Model loaded successfully!")
+        model_loaded = True
         return True
 
     except Exception as e:
         logger.error(f"Error loading model: {e}")
+        logger.error(f"Error type: {type(e)}")
+        model_loaded = False
         return False
 
+def run_model(img):
+    """Run model inference exactly like Colab"""
+    prompt = "<image> Findings:"
+    inputs = processor(images=img, text=prompt, return_tensors="pt").to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+    generated_ids = model.generate(**inputs, max_new_tokens=100)
+    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return result
+
 def analyze_brain_scan(image, patient_name="", patient_age="", symptoms=""):
     """Analyze brain scan image and return medical findings"""
     try:
+        if not model_loaded or model is None:
+            return """
+## ⚠️ Model Loading Error
+
+The AI model is not available. This could be due to:
+- Model loading issues
+- Memory limitations
+- Network connectivity
+
+Please check the logs or try refreshing.
+"""
+
         if image is None:
             return "Please upload a brain scan image."
 
@@ -61,22 +84,8 @@ def analyze_brain_scan(image, patient_name="", patient_age="", symptoms=""):
         if not isinstance(image, Image.Image):
             image = Image.fromarray(image).convert("RGB")
 
-        # Run AI inference
-        prompt = "<image> Findings:"
-        inputs = processor(
-            images=image,
-            text=prompt,
-            return_tensors="pt"
-        ).to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-
-        with torch.no_grad():
-            generated_ids = model.generate(**inputs, max_new_tokens=100)
-
-        result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        # Clean up the result
-        if result.startswith(prompt):
-            result = result[len(prompt):].strip()
+        # Run AI inference using Colab method
+        result = run_model(image)
 
         # Format the response
         timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
@@ -106,6 +115,9 @@ Always consult with qualified medical professionals for actual diagnosis.
 def create_api_response(image, patient_name="", patient_age="", symptoms=""):
     """Create API-compatible response for integration"""
     try:
+        if not model_loaded or model is None:
+            return {"error": "Model not loaded"}
+
         if image is None:
             return {"error": "No image provided"}
 
@@ -113,24 +125,10 @@ def create_api_response(image, patient_name="", patient_age="", symptoms=""):
         if not isinstance(image, Image.Image):
             image = Image.fromarray(image).convert("RGB")
 
-        # Run AI inference
-        prompt = "<image> Findings:"
-        inputs = processor(
-            images=image,
-            text=prompt,
-            return_tensors="pt"
-        ).to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+        # Run AI inference using Colab method
+        result = run_model(image)
 
-        with torch.no_grad():
-            generated_ids = model.generate(**inputs, max_new_tokens=100)
-
-        result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        # Clean up the result
-        if result.startswith(prompt):
-            result = result[len(prompt):].strip()
-
-        # Create API response
+        # Create API response (matching your original format)
         response = {
             "prediction": result,
             "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
@@ -141,7 +139,8 @@ def create_api_response(image, patient_name="", patient_age="", symptoms=""):
             },
             "model_info": {
                 "model_id": "mychen76/paligemma-3b-mix-448-med_30k-ct-brain",
-                "device": str(device)
+                "device": str(device),
+                "model_loaded": model_loaded
             }
         }
 
@@ -151,9 +150,23 @@ def create_api_response(image, patient_name="", patient_age="", symptoms=""):
         logger.error(f"API error: {e}")
         return {"error": f"Analysis failed: {str(e)}"}
 
+def get_model_status():
+    """Get current model status"""
+    return f"""
+## 🤖 Model Status
+
+- **Model Loaded**: {model_loaded}
+- **Device**: {device}
+- **CUDA Available**: {torch.cuda.is_available()}
+- **Model Object**: {type(model).__name__ if model else 'None'}
+- **Processor Object**: {type(processor).__name__ if processor else 'None'}
+- **PyTorch Version**: {torch.__version__}
+"""
+
 # Load model at startup
-logger.info("Initializing Brain CT Analyzer...")
-if load_model():
+logger.info("Initializing Brain CT Analyzer (Colab Style)...")
+load_success = load_model()
+if load_success:
     logger.info("Model loaded successfully!")
 else:
     logger.error("Failed to load model!")
@@ -169,6 +182,12 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
     **⚠️ Important:** This is for educational/research purposes only. Always consult qualified medical professionals.
     """)
 
+    # Model status section
+    with gr.Accordion("🔧 Model Status", open=not model_loaded):
+        status_output = gr.Markdown(value=get_model_status())
+        refresh_btn = gr.Button("🔄 Refresh Status")
+        refresh_btn.click(fn=get_model_status, outputs=status_output)
+
     with gr.Row():
         with gr.Column(scale=1):
             image_input = gr.Image(
@@ -195,13 +214,14 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
             analyze_btn = gr.Button(
                 "🔍 Analyze Brain Scan",
                 variant="primary",
-                size="lg"
+                size="lg",
+                interactive=model_loaded
             )
 
         with gr.Column(scale=1):
             result_output = gr.Markdown(
                 label="Analysis Results",
-                value="Upload an image and click 'Analyze Brain Scan' to see results."
+                value="Upload an image and click 'Analyze Brain Scan' to see results." if model_loaded else "⚠️ Model not loaded. Check status above."
             )
 
     # API endpoint simulation
@@ -209,19 +229,20 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
         api_output = gr.JSON(label="API Response Format")
 
     # Event handlers
-    analyze_btn.click(
-        fn=analyze_brain_scan,
-        inputs=[image_input, patient_name, patient_age, symptoms],
-        outputs=result_output
-    )
-
-    analyze_btn.click(
-        fn=create_api_response,
-        inputs=[image_input, patient_name, patient_age, symptoms],
-        outputs=api_output
-    )
+    if model_loaded:
+        analyze_btn.click(
+            fn=analyze_brain_scan,
+            inputs=[image_input, patient_name, patient_age, symptoms],
+            outputs=result_output
+        )
+
+        analyze_btn.click(
+            fn=create_api_response,
+            inputs=[image_input, patient_name, patient_age, symptoms],
+            outputs=api_output
+        )
 
-    # Example images (if you have any)
+    # Instructions
     gr.Markdown("""
     ## 📋 Usage Instructions:
     1. Upload a brain CT scan image (JPEG or PNG)
@@ -231,6 +252,9 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
 
     ## 🔗 Integration:
     This interface can be integrated with your medical app using the Gradio API.
+
+    ## ✅ Based on Working Colab Code:
+    This version uses the exact same model loading and inference code as your working Google Colab notebook.
     """)
 
 if __name__ == "__main__":
@@ -238,4 +262,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=7860,
         share=True
-    )
+    )
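A note on the new `run_model` helper: it copies the Colab cell faithfully, but compared with the inline code it replaces it drops the explicit `torch.no_grad()` guard (likely harmless, since recent transformers releases already run `generate` under no-grad) and the cleanup step that attempted to strip the echoed prompt from the decoded output, so callers may now receive text that still begins with the prompt. A minimal sketch of the helper with both details kept, assuming app.py's module-level `processor`, `model`, and `device`:

```python
import torch

def run_model(img):
    # Sketch only; processor, model, and device are app.py's module globals.
    prompt = "<image> Findings:"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    inputs = processor(images=img, text=prompt, return_tensors="pt").to(device, dtype=dtype)
    with torch.no_grad():  # explicit no-grad guard, as in the removed inline code
        generated_ids = model.generate(**inputs, max_new_tokens=100)
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Strip the echoed prompt if present, as the pre-refactor code attempted
    if result.startswith(prompt):
        result = result[len(prompt):].strip()
    return result
```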
188
+ refresh_btn = gr.Button("πŸ”„ Refresh Status")
189
+ refresh_btn.click(fn=get_model_status, outputs=status_output)
190
+
191
  with gr.Row():
192
  with gr.Column(scale=1):
193
  image_input = gr.Image(
 
214
  analyze_btn = gr.Button(
215
  "πŸ” Analyze Brain Scan",
216
  variant="primary",
217
+ size="lg",
218
+ interactive=model_loaded
219
  )
220
 
221
  with gr.Column(scale=1):
222
  result_output = gr.Markdown(
223
  label="Analysis Results",
224
+ value="Upload an image and click 'Analyze Brain Scan' to see results." if model_loaded else "⚠️ Model not loaded. Check status above."
225
  )
226
 
227
  # API endpoint simulation
 
229
  api_output = gr.JSON(label="API Response Format")
230
 
231
  # Event handlers
232
+ if model_loaded:
233
+ analyze_btn.click(
234
+ fn=analyze_brain_scan,
235
+ inputs=[image_input, patient_name, patient_age, symptoms],
236
+ outputs=result_output
237
+ )
238
+
239
+ analyze_btn.click(
240
+ fn=create_api_response,
241
+ inputs=[image_input, patient_name, patient_age, symptoms],
242
+ outputs=api_output
243
+ )
244
 
245
+ # Instructions
246
  gr.Markdown("""
247
  ## πŸ“‹ Usage Instructions:
248
  1. Upload a brain CT scan image (JPEG or PNG)
 
252
 
253
  ## πŸ”— Integration:
254
  This interface can be integrated with your medical app using the Gradio API.
255
+
256
+ ## βœ… Based on Working Colab Code:
257
+ This version uses the exact same model loading and inference code as your working Google Colab notebook.
258
  """)
259
 
260
  if __name__ == "__main__":
 
262
  server_name="0.0.0.0",
263
  server_port=7860,
264
  share=True
265
+ )
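On the "Integration" section in the UI text: a Space like this can be called from another application with the `gradio_client` package. A hedged example — the Space id is a placeholder, and the `api_name` is an assumption (Gradio derives endpoint names for unnamed `.click` handlers from the function names; `client.view_api()` prints the actual ones):

```python
from gradio_client import Client, handle_file

# Placeholder Space id; replace with the real one.
client = Client("your-username/brain-ct-analyzer")

# Inspect the available endpoints and their signatures.
client.view_api()

# api_name is an assumption; confirm it against the view_api() output.
report = client.predict(
    handle_file("brain_ct_example.png"),  # image
    "Jane Doe",                           # patient_name
    "54",                                 # patient_age
    "persistent headache",                # symptoms
    api_name="/analyze_brain_scan",
)
print(report)
```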