Spaces:
Sleeping
Sleeping
should work
Browse files
app.py
CHANGED
@@ -16,44 +16,67 @@ logger = logging.getLogger(__name__)
|
|
16 |
model = None
|
17 |
processor = None
|
18 |
device = None
|
|
|
19 |
|
20 |
def load_model():
|
21 |
-
"""Load the AI model
|
22 |
-
global model, processor, device
|
23 |
|
24 |
-
logger.info("Loading AI model...")
|
25 |
-
|
26 |
-
# Get Hugging Face token from environment
|
27 |
-
hf_token = os.environ.get('HF_TOKEN')
|
28 |
|
|
|
29 |
model_id = "mychen76/paligemma-3b-mix-448-med_30k-ct-brain"
|
30 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
32 |
|
33 |
logger.info(f"Using device: {device}")
|
|
|
|
|
34 |
|
35 |
try:
|
36 |
-
# Load
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
)
|
41 |
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
42 |
model_id,
|
43 |
-
torch_dtype=dtype
|
44 |
-
token=hf_token
|
45 |
).to(device).eval()
|
46 |
|
47 |
logger.info("Model loaded successfully!")
|
|
|
48 |
return True
|
49 |
|
50 |
except Exception as e:
|
51 |
logger.error(f"Error loading model: {e}")
|
|
|
|
|
52 |
return False
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def analyze_brain_scan(image, patient_name="", patient_age="", symptoms=""):
|
55 |
"""Analyze brain scan image and return medical findings"""
|
56 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
if image is None:
|
58 |
return "Please upload a brain scan image."
|
59 |
|
@@ -61,22 +84,8 @@ def analyze_brain_scan(image, patient_name="", patient_age="", symptoms=""):
|
|
61 |
if not isinstance(image, Image.Image):
|
62 |
image = Image.fromarray(image).convert("RGB")
|
63 |
|
64 |
-
# Run AI inference
|
65 |
-
|
66 |
-
inputs = processor(
|
67 |
-
images=image,
|
68 |
-
text=prompt,
|
69 |
-
return_tensors="pt"
|
70 |
-
).to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
|
71 |
-
|
72 |
-
with torch.no_grad():
|
73 |
-
generated_ids = model.generate(**inputs, max_new_tokens=100)
|
74 |
-
|
75 |
-
result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
76 |
-
|
77 |
-
# Clean up the result
|
78 |
-
if result.startswith(prompt):
|
79 |
-
result = result[len(prompt):].strip()
|
80 |
|
81 |
# Format the response
|
82 |
timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
|
@@ -106,6 +115,9 @@ Always consult with qualified medical professionals for actual diagnosis.
|
|
106 |
def create_api_response(image, patient_name="", patient_age="", symptoms=""):
|
107 |
"""Create API-compatible response for integration"""
|
108 |
try:
|
|
|
|
|
|
|
109 |
if image is None:
|
110 |
return {"error": "No image provided"}
|
111 |
|
@@ -113,24 +125,10 @@ def create_api_response(image, patient_name="", patient_age="", symptoms=""):
|
|
113 |
if not isinstance(image, Image.Image):
|
114 |
image = Image.fromarray(image).convert("RGB")
|
115 |
|
116 |
-
# Run AI inference
|
117 |
-
|
118 |
-
inputs = processor(
|
119 |
-
images=image,
|
120 |
-
text=prompt,
|
121 |
-
return_tensors="pt"
|
122 |
-
).to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
|
123 |
|
124 |
-
|
125 |
-
generated_ids = model.generate(**inputs, max_new_tokens=100)
|
126 |
-
|
127 |
-
result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
128 |
-
|
129 |
-
# Clean up the result
|
130 |
-
if result.startswith(prompt):
|
131 |
-
result = result[len(prompt):].strip()
|
132 |
-
|
133 |
-
# Create API response
|
134 |
response = {
|
135 |
"prediction": result,
|
136 |
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
@@ -141,7 +139,8 @@ def create_api_response(image, patient_name="", patient_age="", symptoms=""):
|
|
141 |
},
|
142 |
"model_info": {
|
143 |
"model_id": "mychen76/paligemma-3b-mix-448-med_30k-ct-brain",
|
144 |
-
"device": str(device)
|
|
|
145 |
}
|
146 |
}
|
147 |
|
@@ -151,9 +150,23 @@ def create_api_response(image, patient_name="", patient_age="", symptoms=""):
|
|
151 |
logger.error(f"API error: {e}")
|
152 |
return {"error": f"Analysis failed: {str(e)}"}
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
# Load model at startup
|
155 |
-
logger.info("Initializing Brain CT Analyzer...")
|
156 |
-
|
|
|
157 |
logger.info("Model loaded successfully!")
|
158 |
else:
|
159 |
logger.error("Failed to load model!")
|
@@ -169,6 +182,12 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
|
|
169 |
**β οΈ Important:** This is for educational/research purposes only. Always consult qualified medical professionals.
|
170 |
""")
|
171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
with gr.Row():
|
173 |
with gr.Column(scale=1):
|
174 |
image_input = gr.Image(
|
@@ -195,13 +214,14 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
|
|
195 |
analyze_btn = gr.Button(
|
196 |
"π Analyze Brain Scan",
|
197 |
variant="primary",
|
198 |
-
size="lg"
|
|
|
199 |
)
|
200 |
|
201 |
with gr.Column(scale=1):
|
202 |
result_output = gr.Markdown(
|
203 |
label="Analysis Results",
|
204 |
-
value="Upload an image and click 'Analyze Brain Scan' to see results."
|
205 |
)
|
206 |
|
207 |
# API endpoint simulation
|
@@ -209,19 +229,20 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
|
|
209 |
api_output = gr.JSON(label="API Response Format")
|
210 |
|
211 |
# Event handlers
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
223 |
|
224 |
-
#
|
225 |
gr.Markdown("""
|
226 |
## π Usage Instructions:
|
227 |
1. Upload a brain CT scan image (JPEG or PNG)
|
@@ -231,6 +252,9 @@ with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
|
|
231 |
|
232 |
## π Integration:
|
233 |
This interface can be integrated with your medical app using the Gradio API.
|
|
|
|
|
|
|
234 |
""")
|
235 |
|
236 |
if __name__ == "__main__":
|
@@ -238,4 +262,4 @@ if __name__ == "__main__":
|
|
238 |
server_name="0.0.0.0",
|
239 |
server_port=7860,
|
240 |
share=True
|
241 |
-
)
|
|
|
16 |
model = None
|
17 |
processor = None
|
18 |
device = None
|
19 |
+
model_loaded = False
|
20 |
|
21 |
def load_model():
|
22 |
+
"""Load the AI model exactly like in Colab"""
|
23 |
+
global model, processor, device, model_loaded
|
24 |
|
25 |
+
logger.info("Loading AI model (Colab style)...")
|
|
|
|
|
|
|
26 |
|
27 |
+
# === Load AI Model === (exactly like Colab)
|
28 |
model_id = "mychen76/paligemma-3b-mix-448-med_30k-ct-brain"
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
31 |
|
32 |
logger.info(f"Using device: {device}")
|
33 |
+
logger.info(f"Using dtype: {dtype}")
|
34 |
+
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
35 |
|
36 |
try:
|
37 |
+
# Load exactly like Colab (no token, no trust_remote_code)
|
38 |
+
logger.info("Loading processor...")
|
39 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
40 |
+
|
41 |
+
logger.info("Loading model...")
|
42 |
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
43 |
model_id,
|
44 |
+
torch_dtype=dtype
|
|
|
45 |
).to(device).eval()
|
46 |
|
47 |
logger.info("Model loaded successfully!")
|
48 |
+
model_loaded = True
|
49 |
return True
|
50 |
|
51 |
except Exception as e:
|
52 |
logger.error(f"Error loading model: {e}")
|
53 |
+
logger.error(f"Error type: {type(e)}")
|
54 |
+
model_loaded = False
|
55 |
return False
|
56 |
|
57 |
+
def run_model(img):
|
58 |
+
"""Run model inference exactly like Colab"""
|
59 |
+
prompt = "<image> Findings:"
|
60 |
+
inputs = processor(images=img, text=prompt, return_tensors="pt").to(device, dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
|
61 |
+
generated_ids = model.generate(**inputs, max_new_tokens=100)
|
62 |
+
result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
63 |
+
return result
|
64 |
+
|
65 |
def analyze_brain_scan(image, patient_name="", patient_age="", symptoms=""):
|
66 |
"""Analyze brain scan image and return medical findings"""
|
67 |
try:
|
68 |
+
if not model_loaded or model is None:
|
69 |
+
return """
|
70 |
+
## β οΈ Model Loading Error
|
71 |
+
|
72 |
+
The AI model is not available. This could be due to:
|
73 |
+
- Model loading issues
|
74 |
+
- Memory limitations
|
75 |
+
- Network connectivity
|
76 |
+
|
77 |
+
Please check the logs or try refreshing.
|
78 |
+
"""
|
79 |
+
|
80 |
if image is None:
|
81 |
return "Please upload a brain scan image."
|
82 |
|
|
|
84 |
if not isinstance(image, Image.Image):
|
85 |
image = Image.fromarray(image).convert("RGB")
|
86 |
|
87 |
+
# Run AI inference using Colab method
|
88 |
+
result = run_model(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Format the response
|
91 |
timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
115 |
def create_api_response(image, patient_name="", patient_age="", symptoms=""):
|
116 |
"""Create API-compatible response for integration"""
|
117 |
try:
|
118 |
+
if not model_loaded or model is None:
|
119 |
+
return {"error": "Model not loaded"}
|
120 |
+
|
121 |
if image is None:
|
122 |
return {"error": "No image provided"}
|
123 |
|
|
|
125 |
if not isinstance(image, Image.Image):
|
126 |
image = Image.fromarray(image).convert("RGB")
|
127 |
|
128 |
+
# Run AI inference using Colab method
|
129 |
+
result = run_model(image)
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
+
# Create API response (matching your original format)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
response = {
|
133 |
"prediction": result,
|
134 |
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
139 |
},
|
140 |
"model_info": {
|
141 |
"model_id": "mychen76/paligemma-3b-mix-448-med_30k-ct-brain",
|
142 |
+
"device": str(device),
|
143 |
+
"model_loaded": model_loaded
|
144 |
}
|
145 |
}
|
146 |
|
|
|
150 |
logger.error(f"API error: {e}")
|
151 |
return {"error": f"Analysis failed: {str(e)}"}
|
152 |
|
153 |
+
def get_model_status():
|
154 |
+
"""Get current model status"""
|
155 |
+
return f"""
|
156 |
+
## π€ Model Status
|
157 |
+
|
158 |
+
- **Model Loaded**: {model_loaded}
|
159 |
+
- **Device**: {device}
|
160 |
+
- **CUDA Available**: {torch.cuda.is_available()}
|
161 |
+
- **Model Object**: {type(model).__name__ if model else 'None'}
|
162 |
+
- **Processor Object**: {type(processor).__name__ if processor else 'None'}
|
163 |
+
- **PyTorch Version**: {torch.__version__}
|
164 |
+
"""
|
165 |
+
|
166 |
# Load model at startup
|
167 |
+
logger.info("Initializing Brain CT Analyzer (Colab Style)...")
|
168 |
+
load_success = load_model()
|
169 |
+
if load_success:
|
170 |
logger.info("Model loaded successfully!")
|
171 |
else:
|
172 |
logger.error("Failed to load model!")
|
|
|
182 |
**β οΈ Important:** This is for educational/research purposes only. Always consult qualified medical professionals.
|
183 |
""")
|
184 |
|
185 |
+
# Model status section
|
186 |
+
with gr.Accordion("π§ Model Status", open=not model_loaded):
|
187 |
+
status_output = gr.Markdown(value=get_model_status())
|
188 |
+
refresh_btn = gr.Button("π Refresh Status")
|
189 |
+
refresh_btn.click(fn=get_model_status, outputs=status_output)
|
190 |
+
|
191 |
with gr.Row():
|
192 |
with gr.Column(scale=1):
|
193 |
image_input = gr.Image(
|
|
|
214 |
analyze_btn = gr.Button(
|
215 |
"π Analyze Brain Scan",
|
216 |
variant="primary",
|
217 |
+
size="lg",
|
218 |
+
interactive=model_loaded
|
219 |
)
|
220 |
|
221 |
with gr.Column(scale=1):
|
222 |
result_output = gr.Markdown(
|
223 |
label="Analysis Results",
|
224 |
+
value="Upload an image and click 'Analyze Brain Scan' to see results." if model_loaded else "β οΈ Model not loaded. Check status above."
|
225 |
)
|
226 |
|
227 |
# API endpoint simulation
|
|
|
229 |
api_output = gr.JSON(label="API Response Format")
|
230 |
|
231 |
# Event handlers
|
232 |
+
if model_loaded:
|
233 |
+
analyze_btn.click(
|
234 |
+
fn=analyze_brain_scan,
|
235 |
+
inputs=[image_input, patient_name, patient_age, symptoms],
|
236 |
+
outputs=result_output
|
237 |
+
)
|
238 |
+
|
239 |
+
analyze_btn.click(
|
240 |
+
fn=create_api_response,
|
241 |
+
inputs=[image_input, patient_name, patient_age, symptoms],
|
242 |
+
outputs=api_output
|
243 |
+
)
|
244 |
|
245 |
+
# Instructions
|
246 |
gr.Markdown("""
|
247 |
## π Usage Instructions:
|
248 |
1. Upload a brain CT scan image (JPEG or PNG)
|
|
|
252 |
|
253 |
## π Integration:
|
254 |
This interface can be integrated with your medical app using the Gradio API.
|
255 |
+
|
256 |
+
## β
Based on Working Colab Code:
|
257 |
+
This version uses the exact same model loading and inference code as your working Google Colab notebook.
|
258 |
""")
|
259 |
|
260 |
if __name__ == "__main__":
|
|
|
262 |
server_name="0.0.0.0",
|
263 |
server_port=7860,
|
264 |
share=True
|
265 |
+
)
|