Spaces: Runtime error
import torch
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers import AutoencoderKL
import numpy as np
import gradio as gr

# Configure device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Initialize medical imaging components
def load_medical_models():
    try:
        # Load processor and tokenizer
        processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")

        # Load base model
        model = MultiModalityCausalLM.from_pretrained(
            "deepseek-ai/Janus-1.3B",
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
        ).to(device).eval()

        # Load VAE for image processing
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sdxl-vae",
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
        ).to(device).eval()

        return processor, model, vae
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise

processor, model, vae = load_medical_models()
# Medical image analysis function
def medical_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
    try:
        # Set random seed for reproducibility
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Prepare inputs: the Janus chat processor expects a conversation
        # whose user turn carries an <image_placeholder> tag for the image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        conversation = [
            {
                "role": "User",
                "content": f"<image_placeholder>\n{question}",
                "images": [image],
            },
            {"role": "Assistant", "content": ""},
        ]
        inputs = processor(
            conversations=conversation,
            images=[image],
            force_batchify=True
        ).to(device)

        # Generate analysis: embed the image/text jointly, then decode with the language head
        inputs_embeds = model.prepare_inputs_embeds(**inputs)
        outputs = model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs.attention_mask,
            pad_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=512,
            do_sample=True,
            temperature=temperature,
            top_p=top_p
        )
        return processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    except Exception as e:
        return f"Analysis error: {str(e)}"
# Medical interface
with gr.Blocks(title="Medical Imaging Assistant") as demo:
    gr.Markdown("# Medical Imaging AI Assistant")
    with gr.Tab("Analysis"):
        with gr.Row():
            med_image = gr.Image(label="Input Image", type="pil")
            med_question = gr.Textbox(label="Clinical Query")
        analysis_output = gr.Textbox(label="Findings")
        gr.Examples(
            examples=[
                ["ultrasound_sample.jpg", "Identify any abnormalities in this ultrasound"],
                ["xray_sample.jpg", "Describe the bone structure visible in this X-ray"]
            ],
            inputs=[med_image, med_question]
        )
        med_question.submit(
            medical_analysis,
            inputs=[med_image, med_question],
            outputs=analysis_output
        )

demo.launch()
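For the Space to start, the imported packages also have to be declared. A plausible requirements.txt, treated as an assumption based on the imports above (the janus modules here come from DeepSeek's Janus GitHub repository, not the similarly named package on PyPI; exact version pins omitted):

torch
transformers
diffusers
gradio
numpy
Pillow
git+https://github.com/deepseek-ai/Janus.git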