# Hugging Face Spaces status: Runtime error
import torch | |
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor | |
from PIL import Image | |
from diffusers.models import AutoencoderKL | |
import numpy as np | |
import gradio as gr | |
import warnings | |
# Silence every Python warning process-wide to keep model-loading logs quiet.
warnings.filterwarnings(action="ignore")

# All models and tensors in this app are placed on the CPU.
device = torch.device("cpu")
print(f"Using device: {device.type}")
# Hub checkpoints and size limits shared by model loading, inference and the UI.
MEDICAL_MODEL_CONFIG = dict(
    model_path="deepseek-ai/JanusFlow-1.3B",  # processor + multimodal model checkpoint
    vae_path="stabilityai/sdxl-vae",  # AutoencoderKL checkpoint
    max_analysis_length=512,  # token limit for prompt encoding and generated text
    min_image_size=512,  # side length generated images are resized to
    max_image_size=1024,  # height of the output gallery in the UI
)
# Load the JanusFlow processor/model and the SDXL VAE once, at import time.
# Any failure is printed and re-raised so the app does not start half-loaded.
try:
    # Only real loader arguments are passed. transformers forwards unknown
    # keyword arguments (e.g. a made-up "medical_weights") to the model's
    # constructor, which rejects them with a TypeError.
    vl_chat_processor = VLChatProcessor.from_pretrained(
        MEDICAL_MODEL_CONFIG["model_path"]
    )
    tokenizer = vl_chat_processor.tokenizer
    vl_gpt = (
        MultiModalityCausalLM.from_pretrained(MEDICAL_MODEL_CONFIG["model_path"])
        .to(device)
        .eval()
    )
    # stabilityai/sdxl-vae keeps its config and weights at the repository
    # root; there is no "vae" subfolder, so none is requested.
    vae = (
        AutoencoderKL.from_pretrained(MEDICAL_MODEL_CONFIG["vae_path"])
        .to(device)
        .eval()
    )
except Exception as e:
    print(f"Error loading medical models: {str(e)}")
    raise
# Image question answering: image + free-text question -> text answer.
def medical_image_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
    """Answer ``question`` about ``image`` with the JanusFlow language model.

    Args:
        image: The uploaded image. The Gradio input uses ``type="pil"``, so
            this is normally a PIL image. A NumPy array is also accepted.
            ``None`` means nothing was uploaded.
        question: The user's free-text question.
        seed: Seed applied to the torch and NumPy RNGs before generation.
        top_p: Nucleus-sampling threshold (used only when sampling).
        temperature: Sampling temperature. ``0`` selects greedy decoding.

    Returns:
        The decoded answer text, passed through ``clean_medical_report``.
        On failure, a string starting with ``"Medical analysis error:"``.
        This function never raises, so the output textbox always gets text.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if image is None:
        return "Medical analysis error: please upload an image."
    if not question or not question.strip():
        return "Medical analysis error: please enter a clinical question."
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Uploads may be RGBA / LA / palette images. Normalise to RGB for
        # every input, not only for NumPy arrays.
        image = image.convert("RGB")

        # JanusFlow's chat template expects "User"/"Assistant" roles and the
        # "<image_placeholder>" tag. With another tag, the processor finds no
        # image slot for the supplied image.
        conversation = [
            {
                "role": "User",
                "content": f"<image_placeholder>\n{question}",
                "images": [image],
            },
            {"role": "Assistant", "content": ""},
        ]
        # NOTE(review): dtype=vl_gpt.dtype keeps pixel_values in the model's
        # dtype. JanusFlow's batch .to() defaults to bfloat16. TODO confirm
        # against the installed janus version.
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=[image],
            force_batchify=True,
        ).to(device, dtype=vl_gpt.dtype)

        # The processor yields token ids and pixel values. They must be fused
        # into input embeddings before the language model can consume them.
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
        outputs = vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=MEDICAL_MODEL_CONFIG["max_analysis_length"],
            # Without do_sample=True, generate() ignores temperature/top_p.
            do_sample=temperature != 0,
            temperature=temperature,
            top_p=top_p,
            use_cache=True,
        )
        report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return clean_medical_report(report)
    except Exception as e:
        # UI boundary: surface the problem as text instead of crashing the handler.
        return f"Medical analysis error: {str(e)}"
# Synthetic image generation: text prompt -> list of images for the gallery.
def generate_medical_image(prompt, seed=12345, guidance=5, steps=30):
    """Generate images for ``prompt`` and return them for a ``gr.Gallery``.

    Args:
        prompt: Text prompt. It must pass ``validate_medical_prompt``.
        seed: Seed applied to the torch RNG before generation.
        guidance: Value forwarded as ``guidance_scale`` to the latent generator.
        steps: Value forwarded as ``num_inference_steps``.

    Returns:
        The list produced by ``postprocess_medical_images``.

    Raises:
        gr.Error: When the prompt is rejected or any pipeline step fails.
            Gradio displays the message to the user. Returning strings
            instead would make the Gallery try to load them as image files.
    """
    torch.manual_seed(seed)
    try:
        if not validate_medical_prompt(prompt):
            raise gr.Error(
                "Invalid medical prompt - please provide specific anatomical details"
            )
        # NOTE(review): encode_medical_prompt, generate_medical_latents and
        # AutoencoderKL.decode_latents do not appear in the public JanusFlow /
        # diffusers APIs. TODO confirm they exist in the installed packages;
        # otherwise this path always ends in the gr.Error below.
        inputs = vl_chat_processor.encode_medical_prompt(
            prompt,
            max_length=MEDICAL_MODEL_CONFIG["max_analysis_length"],
            device=device,
        )
        with torch.autocast(device.type):
            latents = vl_gpt.generate_medical_latents(
                inputs,
                guidance_scale=guidance,
                num_inference_steps=steps,
            )
            images = vae.decode_latents(latents)
        return postprocess_medical_images(images)
    except gr.Error:
        # Already user-facing; do not re-wrap it.
        raise
    except Exception as e:
        raise gr.Error(f"Medical imaging error: {str(e)}") from e
# Helper functions
def validate_medical_prompt(prompt):
    """Return True if ``prompt`` names an imaging modality or anatomy term.

    Recognised terms: MRI, CT, X-ray, ultrasound, histology, anatomy.
    Matching is case-insensitive and on whole words, with an optional plural
    "s". So "MRI", "mri" and "MRIs" all match, while "ct" inside an
    unrelated word such as "picture" does not.

    Args:
        prompt: User-supplied prompt text. ``None`` or empty returns False.

    Returns:
        bool
    """
    import re

    if not prompt:
        return False
    # Terms are compared case-insensitively. Word boundaries stop the short
    # term "ct" from matching inside words like "picture" or "object".
    pattern = r"\b(?:mri|ct|x-ray|ultrasound|histology|anatomy)s?\b"
    return re.search(pattern, prompt, flags=re.IGNORECASE) is not None
def postprocess_medical_images(images):
    """Turn decoded image arrays into square PIL images for the gallery.

    Each item is passed to ``Image.fromarray``. It is presumably a uint8
    HxW or HxWxC array; TODO confirm against the decoder output. Each image
    is resized with Lanczos resampling to a square whose side is
    ``MEDICAL_MODEL_CONFIG["min_image_size"]``.

    Returns:
        A list of PIL images in input order.
    """
    side = MEDICAL_MODEL_CONFIG["min_image_size"]
    return [
        Image.fromarray(array).resize((side, side), Image.LANCZOS)
        for array in images
    ]
def clean_medical_report(text):
    """Remove every "##MEDICAL_REPORT##" marker, then trim outer whitespace."""
    pieces = text.split("##MEDICAL_REPORT##")
    return "".join(pieces).strip()
# Gradio UI:
# - an analysis tab (image + question -> text);
# - a generation tab (prompt -> gallery).
# Event listeners are attached inside the Blocks context, as Gradio requires.
with gr.Blocks(title="Medical Imaging AI Assistant", theme="soft") as demo:
    # The app wraps a general-purpose research checkpoint. The banner must
    # not claim a clinical certification the software does not have.
    gr.Markdown(
        "# Medical Imaging Analysis & Generation System\n"
        "**Research demo only - not certified for clinical or diagnostic use.**"
    )

    with gr.Tab("Radiology Analysis"):
        with gr.Row():
            gr.Markdown("## Patient Imaging Analysis")
            with gr.Column():
                # type="pil" hands the callback a PIL image. PIL cannot decode
                # DICOM, so only ordinary raster formats can be uploaded.
                medical_image = gr.Image(
                    label="Medical Image (PNG/JPEG; DICOM not supported)",
                    type="pil",
                )
                clinical_query = gr.Textbox(label="Clinical Question")
                analysis_btn = gr.Button("Generate Report", variant="primary")
            report_output = gr.Textbox(label="Clinical Findings", interactive=False)

    with gr.Tab("Diagnostic Imaging Generation"):
        with gr.Row():
            gr.Markdown("## Synthetic Medical Image Generation")
            with gr.Column():
                imaging_protocol = gr.Textbox(label="Imaging Protocol")
                generate_btn = gr.Button("Generate Study", variant="primary")
            study_gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                height=MEDICAL_MODEL_CONFIG["max_image_size"],
            )

    # Wire buttons to handlers. Only the listed inputs are passed; the
    # handlers' remaining parameters keep their defaults.
    analysis_btn.click(
        medical_image_analysis,
        inputs=[medical_image, clinical_query],
        outputs=report_output,
    )
    generate_btn.click(
        generate_medical_image,
        inputs=[imaging_protocol],
        outputs=study_gallery,
    )
# Start the server on all interfaces, port 7860.
# Request queuing is enabled via Blocks.queue(). Gradio 4 removed the
# `enable_queue` argument from launch(), and passing it raises a TypeError.
# queue() returns the Blocks instance, so the calls chain.
demo.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
    max_threads=2,
    show_error=True,
)