# NOTE(review): the lines that previously appeared here were scraped web-page
# chrome (Spaces "Runtime error" banner, file size, git blame hashes, and
# gutter line numbers), not Python source. They made the file unparseable and
# have been replaced by this comment.
import torch
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers import AutoencoderKL
import numpy as np
import gradio as gr
# Select the compute device and a matching precision: bfloat16 on GPU,
# full float32 on CPU (bfloat16 matmuls are poorly supported on CPU).
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.bfloat16
else:
    device = "cpu"
    torch_dtype = torch.float32
print(f"Using device: {device}")
# Initialize medical imaging components
def load_medical_models():
    """Load the Janus chat processor, multimodal LM, and SDXL VAE.

    Returns:
        tuple: ``(processor, model, vae)`` with model and vae moved to
        ``device`` and set to eval mode.

    Raises:
        Exception: re-raised after logging if any download/load fails, so
        the app fails fast at startup instead of limping along.
    """
    try:
        # NOTE(review): the published janus VLChatProcessor.from_pretrained
        # does not accept a `medical_mode` kwarg — passing it raised at
        # startup. Removed; confirm no fork-specific behavior depended on it.
        processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
        # Load model with CPU/GPU optimization.
        model = MultiModalityCausalLM.from_pretrained(
            "deepseek-ai/Janus-1.3B",
            torch_dtype=torch_dtype,
            attn_implementation="eager",  # force standard attention (no flash-attn dependency)
            low_cpu_mem_usage=True,
        ).to(device).eval()
        # Load VAE with reduced precision.
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sdxl-vae",
            torch_dtype=torch_dtype,
        ).to(device).eval()
        return processor, model, vae
    except Exception as e:
        print(f"Error loading medical models: {str(e)}")
        raise

processor, model, vae = load_medical_models()
# Medical image analysis function
def medical_analysis(image, question, seed=42):
    """Generate a radiology-style report for ``image`` given a clinical query.

    Args:
        image: PIL image or numpy array from the gradio Image component;
            ``None`` when the user submitted without uploading.
        question: free-text clinical question about the image.
        seed: RNG seed for reproducible sampling (default 42).

    Returns:
        str: the decoded report, or a human-readable error message
        (errors are returned, not raised, so gradio can display them).
    """
    try:
        # Guard degenerate inputs: gradio passes None when no image was
        # uploaded, which would otherwise crash inside the processor and
        # surface only as an opaque exception string.
        if image is None:
            return "Radiology analysis error: no image provided"
        if not question or not question.strip():
            return "Radiology analysis error: no clinical query provided"
        # Set random seeds for reproducibility across both torch and numpy.
        torch.manual_seed(seed)
        np.random.seed(seed)
        # Normalize the input to an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        # Prepare medical-specific input; the tag wrapping matches what the
        # decode step strips below.
        inputs = processor(
            text=f"<medical_query>{question}</medical_query>",
            images=[image],
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(device)
        # Generate with low temperature for mostly-deterministic reports.
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.1,
            top_p=0.95,
            pad_token_id=processor.tokenizer.eos_token_id,
            do_sample=True
        )
        # Clean and return the report text.
        report = processor.decode(outputs[0], skip_special_tokens=True)
        return report.replace("##MEDICAL_REPORT##", "").strip()
    except Exception as e:
        # Broad catch is deliberate: the UI shows the message instead of a 500.
        return f"Radiology analysis error: {str(e)}"
# Medical interface
# Medical interface: one tab with an image + query on the left-to-right row,
# an analyze button, and a read-only report box.
with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""# AI Radiology Assistant
    **CT/MRI/X-ray Analysis System**""")
    with gr.Tab("Diagnostic Imaging"):
        with gr.Row():
            med_image = gr.Image(label="DICOM Image", type="pil")
            med_question = gr.Textbox(
                label="Clinical Query",
                placeholder="Describe findings in this CT scan...",
            )
        analysis_btn = gr.Button("Analyze", variant="primary")
        report_output = gr.Textbox(label="Radiology Report", interactive=False)
    # Both Enter-in-textbox and button click run the same analysis; wire
    # them identically in one pass instead of duplicating the call.
    for trigger in (med_question.submit, analysis_btn.click):
        trigger(
            medical_analysis,
            inputs=[med_image, med_question],
            outputs=report_output,
        )
# Launch with CPU optimization
# Launch with CPU optimization.
# NOTE(review): `enable_queue` was removed from Blocks.launch() in Gradio 4.x
# and passing it raises TypeError at startup; queueing is enabled via
# demo.queue() instead — confirm the installed gradio major version.
demo.queue()
demo.launch(
    server_name="0.0.0.0",  # listen on all interfaces (required on Spaces)
    server_port=7860,
    max_threads=2,
)