# Hugging Face Spaces page header (scrape residue): "Spaces: Running"
import os | |
import torch | |
import time | |
import gradio as gr | |
import requests | |
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration | |
from peft import PeftModel | |
from PIL import Image | |
from io import BytesIO | |
import logging | |
# Logging: INFO level so the progress messages emitted below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared inference state; load_model() assigns these via `global`.
model = processor = device = None
model_loaded = False
def load_model(
    base_model_id="google/paligemma-3b-mix-448",
    adapter_model_id="mychen76/paligemma-3b-mix-448-med_30k-ct-brain",
):
    """Load the AI model with PEFT adapter (Colab style) into module globals.

    Loads the processor and base model from ``base_model_id``, wraps the base
    model with the PEFT adapter from ``adapter_model_id`` and switches it to
    eval mode. The globals ``model`` and ``processor`` are only replaced once
    every step has succeeded, so a failure part-way through never leaves a
    half-initialised model behind.

    Args:
        base_model_id: Hugging Face repo id of the base PaliGemma checkpoint
            (the processor is loaded from it as well).
        adapter_model_id: Hugging Face repo id of the PEFT adapter.

    Returns:
        True on success (``model_loaded`` set to True). False on any error:
        the error and its traceback are logged, ``model``/``processor`` are
        reset to None and ``model_loaded`` is set to False. No exception
        propagates.
    """
    global model, processor, device, model_loaded
    logger.info("Loading AI model with PEFT adapter (Colab style)...")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # float16 only when CUDA is available; CPU runs in float32.
    dtype = torch.float16 if use_cuda else torch.float32
    logger.info(f"Using device: {device}")
    logger.info(f"Using dtype: {dtype}")
    logger.info(f"CUDA available: {use_cuda}")
    logger.info(f"Base model: {base_model_id}")
    logger.info(f"Adapter model: {adapter_model_id}")

    try:
        logger.info("Loading processor...")
        new_processor = AutoProcessor.from_pretrained(base_model_id)

        logger.info("Loading base model...")
        base_model = PaliGemmaForConditionalGeneration.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            # On GPU let device_map place the weights; on CPU the model is
            # moved explicitly below.
            device_map="auto" if use_cuda else None,
        )

        logger.info("Loading PEFT adapter...")
        new_model = PeftModel.from_pretrained(base_model, adapter_model_id)
        new_model.eval()
        if not use_cuda:
            new_model = new_model.to(device)
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Error loading model: {e}")
        logger.error(f"Error type: {type(e)}")
        message = str(e).lower()
        if "license" in message or "access" in message:
            logger.error("This appears to be a license/access issue with the base model.")
            logger.error("You may need to:")
            logger.error(f"1. Accept the license for {base_model_id} on HuggingFace")
            logger.error("2. Login with: huggingface-cli login")
            logger.error("3. Use your HuggingFace token")
        # Drop any partially loaded objects so they are not reported as
        # available and their memory can be released.
        model = None
        processor = None
        model_loaded = False
        return False

    # Publish the fully initialised objects only after every step succeeded.
    model = new_model
    processor = new_processor
    logger.info("Model loaded successfully!")
    model_loaded = True
    return True
def run_model(img, prompt="<image> Findings:", max_new_tokens=100, strip_prompt=False):
    """Run model inference exactly like Colab and return the decoded text.

    Args:
        img: Image passed to the processor (the callers in this file pass a
            PIL RGB image).
        prompt: Text prompt given to the processor; the default is the
            original Colab prompt.
        max_new_tokens: Maximum number of tokens ``generate`` may add.
        strip_prompt: If True, decode only the tokens generated after the
            prompt. The default (False) decodes the full returned sequence,
            as the original code did, which presumably includes the prompt
            text.

    Returns:
        The decoded text of the first sequence in the batch.

    Raises:
        RuntimeError: If ``load_model`` has not successfully populated the
            ``model``/``processor`` globals.
    """
    if model is None or processor is None:
        raise RuntimeError("Model is not loaded; call load_model() first")

    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    inputs = processor(images=img, text=prompt, return_tensors="pt").to(device, dtype=dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    if strip_prompt:
        # Drop the echoed prompt tokens so only the continuation is decoded.
        prompt_len = inputs["input_ids"].shape[-1]
        generated_ids = generated_ids[:, prompt_len:]
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def analyze_brain_scan(image, patient_name="", patient_age="", symptoms=""):
    """Analyze a brain scan image and return a Markdown report of the findings.

    Args:
        image: Uploaded image, either a PIL image or an array accepted by
            ``Image.fromarray``. May be None.
        patient_name: Optional free text echoed into the report.
        patient_age: Optional free text echoed into the report.
        symptoms: Optional free text echoed into the report.

    Returns:
        A Markdown string. It is the findings report on success, or a help or
        error message when the model is not loaded, no image was supplied or
        inference failed. Exceptions are logged and converted to a message;
        they are never raised to the caller.
    """
    try:
        logger.info("=== ANALYZE FUNCTION CALLED ===")
        logger.info(f"Image received: {image is not None}")
        logger.info(f"Model loaded: {model_loaded}")
        logger.info(f"Model object: {model is not None}")

        if not model_loaded or model is None:
            error_msg = """
## ⚠️ Model Loading Error
The AI model is not available. This could be due to:
- **License Issue**: The base model requires accepting Google's license
- **PEFT Loading Issue**: Problem loading the medical adapter
- **Memory limitations**: Insufficient resources
- **Network connectivity**: Download issues
**To fix this:**
1. Accept the license for `google/paligemma-3b-mix-448` on HuggingFace
2. Login with your HuggingFace token: `huggingface-cli login`
3. Restart the application
Please check the logs for more details.
"""
            logger.error("Model not loaded - returning error message")
            return error_msg

        if image is None:
            logger.warning("No image provided")
            return "## ⚠️ No Image\n\nPlease upload a brain scan image first, then click 'Analyze Brain Scan'."

        logger.info("Converting image to PIL format...")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalise every input to 3-channel RGB, PIL and array inputs alike,
        # so grayscale ("L") or RGBA uploads reach the processor in one format.
        image = image.convert("RGB")

        logger.info("Starting AI inference...")
        result = run_model(image)
        logger.info(f"AI inference completed. Result length: {len(result) if result else 0}")

        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        formatted_result = f"""
## Brain CT Analysis Results
**Patient Information:**
- Name: {patient_name or 'Not provided'}
- Age: {patient_age or 'Not provided'}
- Symptoms: {symptoms or 'Not provided'}
- Analysis Date: {timestamp}
**AI Findings:**
{result}
**Model Info:**
- Base Model: google/paligemma-3b-mix-448
- Medical Adapter: mychen76/paligemma-3b-mix-448-med_30k-ct-brain
- Device: {device}
**Note:** This is an AI-generated analysis for educational purposes only.
Always consult with qualified medical professionals for actual diagnosis.
"""
        logger.info("Analysis completed successfully")
        return formatted_result
    except Exception as e:
        # UI boundary: log the full traceback and show a message instead of
        # raising into Gradio. logger.exception records the traceback itself.
        logger.exception(f"Analysis error: {e}")
        logger.error(f"Error type: {type(e)}")
        return f"""
## ❌ Analysis Error
An error occurred during analysis:
**Error**: {str(e)}
**Error Type**: {type(e).__name__}
Please check the logs for more details and try again.
"""
def create_api_response(image, patient_name="", patient_age="", symptoms=""):
    """Create an API-compatible dict response for integration.

    Args:
        image: Uploaded image, either a PIL image or an array accepted by
            ``Image.fromarray``. May be None.
        patient_name: Optional value copied into ``patient_info``.
        patient_age: Optional value copied into ``patient_info``.
        symptoms: Optional value copied into ``patient_info``.

    Returns:
        On success, a dict with ``prediction``, ``timestamp``, ``patient_info``
        and ``model_info`` keys. Otherwise ``{"error": <message>}`` when the
        model is not loaded, no image was given or inference raised.
        Exceptions are logged, never propagated.
    """
    try:
        logger.info("=== API RESPONSE FUNCTION CALLED ===")
        if not model_loaded or model is None:
            return {"error": "Model not loaded - check license and authentication"}
        if image is None:
            return {"error": "No image provided"}

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalise PIL and array inputs alike to 3-channel RGB.
        image = image.convert("RGB")

        result = run_model(image)

        # Response layout matches the original integration format.
        return {
            "prediction": result,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "patient_info": {
                "name": patient_name,
                "age": patient_age,
                "symptoms": symptoms,
            },
            "model_info": {
                "base_model": "google/paligemma-3b-mix-448",
                "adapter_model": "mychen76/paligemma-3b-mix-448-med_30k-ct-brain",
                "device": str(device),
                "model_loaded": model_loaded,
            },
        }
    except Exception as e:
        # logger.exception records the traceback; no manual traceback import.
        logger.exception(f"API error: {e}")
        return {"error": f"Analysis failed: {str(e)}"}
def get_model_status():
    """Render the current model/runtime state as a Markdown string.

    Reads the module globals ``model_loaded``, ``device``, ``model`` and
    ``processor`` together with torch's CUDA availability and version; the
    configuration and requirements sections are fixed text.
    """
    status_lines = [
        "## π€ Model Status",
        f"- **Model Loaded**: {model_loaded}",
        f"- **Device**: {device}",
        f"- **CUDA Available**: {torch.cuda.is_available()}",
        f"- **Model Object**: {type(model).__name__ if model else 'None'}",
        f"- **Processor Object**: {type(processor).__name__ if processor else 'None'}",
        f"- **PyTorch Version**: {torch.__version__}",
        "## π Model Configuration",
        "- **Base Model**: google/paligemma-3b-mix-448",
        "- **Medical Adapter**: mychen76/paligemma-3b-mix-448-med_30k-ct-brain",
        "- **Model Type**: PEFT/LoRA Fine-tuned",
        "## β οΈ Requirements",
        "- HuggingFace account with accepted license for PaliGemma",
        "- HuggingFace token authentication",
        "- PEFT library for adapter loading",
    ]
    # Leading and trailing newline, matching the original triple-quoted layout.
    return "\n" + "\n".join(status_lines) + "\n"
# Load the model once, at import time, before the interface is built.
logger.info("Initializing Brain CT Analyzer with PEFT (Colab Style)...")
load_success = load_model()
if not load_success:
    logger.error("Failed to load model!")
else:
    logger.info("Model loaded successfully!")
# Create Gradio interface
with gr.Blocks(title="Brain CT Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# π§ Brain CT Analyzer
Upload a brain CT scan image for AI-powered analysis. This tool uses the PaliGemma medical model
with specialized medical fine-tuning to provide preliminary findings.
**β οΈ Important:** This is for educational/research purposes only. Always consult qualified medical professionals.
**π Requirements:** This model requires accepting Google's PaliGemma license and HuggingFace authentication.
""")

    # Model status section; starts expanded only when the model failed to load.
    with gr.Accordion("π§ Model Status", open=not model_loaded):
        status_output = gr.Markdown(value=get_model_status())
        refresh_btn = gr.Button("π Refresh Status")
        refresh_btn.click(fn=get_model_status, outputs=status_output)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Brain CT Scan",
                type="pil",
                height=400
            )
            with gr.Group():
                patient_name = gr.Textbox(
                    label="Patient Name (Optional)",
                    placeholder="Enter patient name"
                )
                patient_age = gr.Textbox(
                    label="Patient Age (Optional)",
                    placeholder="Enter patient age"
                )
                symptoms = gr.Textbox(
                    label="Symptoms (Optional)",
                    placeholder="Describe symptoms",
                    lines=3
                )
            analyze_btn = gr.Button(
                "π Analyze Brain Scan",
                variant="primary",
                size="lg",
                interactive=model_loaded
            )
        with gr.Column(scale=1):
            result_output = gr.Markdown(
                label="Analysis Results",
                value="Upload an image and click 'Analyze Brain Scan' to see results." if model_loaded else "β οΈ Model not loaded. Check status above and ensure license acceptance."
            )

    # API endpoint simulation
    with gr.Accordion("π API Response (for developers)", open=False):
        api_output = gr.JSON(label="API Response Format")

    def test_function():
        """Debug handler: log the click and report whether the model is loaded."""
        logger.info("=== TEST BUTTON CLICKED ===")
        return f"β Test button works! Model loaded: {model_loaded}"

    # Add test button for debugging
    with gr.Row():
        test_btn = gr.Button("π§ͺ Test Button (Debug)", variant="secondary")
        test_output = gr.Textbox(label="Test Output", visible=True)
    test_btn.click(fn=test_function, outputs=test_output)

    # Event handlers - ALWAYS attach, let the function handle the logic.
    # Both handlers call run_model(). Registering them as two independent
    # click listeners could start two generate() calls on the shared model at
    # the same time. Chaining the API preview with .then() makes it run only
    # after the Markdown report has finished.
    analyze_event = analyze_btn.click(
        fn=analyze_brain_scan,
        inputs=[image_input, patient_name, patient_age, symptoms],
        outputs=result_output
    )
    analyze_event.then(
        fn=create_api_response,
        inputs=[image_input, patient_name, patient_age, symptoms],
        outputs=api_output
    )

    # Instructions
    gr.Markdown("""
## π Usage Instructions:
1. **Accept License**: Go to [google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448) and accept the license
2. **Authenticate**: Login with `huggingface-cli login` using your token
3. Upload a brain CT scan image (JPEG or PNG)
4. Optionally fill in patient information
5. Click "Analyze Brain Scan" to get AI findings
6. Review the results in the output panel
## π Integration:
This interface can be integrated with your medical app using the Gradio API.
## β Based on Working Colab Code:
This version uses PEFT to load the medical fine-tuned adapter on top of the base PaliGemma model,
exactly matching your working Google Colab setup.
""")
if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link. This form
    # collects patient details, so confirm that exposure is intended.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)