import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
import spaces
# Load the Janus-Pro-1B base model and processor. The checkpoint itself is not
# medical-specific; domain behaviour comes from prompting (and, ideally, fine-tuning).
model_path = "deepseek-ai/Janus-Pro-1B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'

# Initialize the multimodal model. Note: the upstream Janus-Pro code does not accept a
# `medical_head` kwarg; a medical imaging head would require custom model code.
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
).to(torch.bfloat16 if torch.cuda.is_available() else torch.float16)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# On ZeroGPU Spaces the @spaces.GPU decorator is required to allocate a GPU per call;
# on a standard GPU Space it is harmless (assumption about the hosting hardware).
@spaces.GPU
@torch.inference_mode()
def medical_image_analysis(medical_image, clinical_question, seed, top_p, temperature):
    """Analyze medical images (CT, MRI, X-ray, histopathology) with clinical context."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    torch.manual_seed(seed)

    # Janus expects the <|User|>/<|Assistant|> roles and the <image_placeholder> tag;
    # the clinical framing is carried in the message content instead of custom roles.
    conversation = [{
        "role": "<|User|>",
        "content": f"<image_placeholder>\nYou are a radiology assistant. Clinical Context: {clinical_question}",
        "images": [medical_image],
    }, {"role": "<|Assistant|>", "content": ""}]

    processed_image = [Image.fromarray(medical_image)]
    inputs = vl_chat_processor(
        conversations=conversation,
        images=processed_image,
        force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

    # Generation parameters come from the UI; low temperature favours clinical precision
    # and a repetition penalty reduces repeated or hallucinated findings.
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.2,
        use_cache=True,
    )
    findings = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return f"Clinical Findings:\n{findings}"
@spaces.GPU
@torch.inference_mode()
def generate_medical_image(prompt, seed=None, guidance=5, t2i_temperature=0.5):
    """Generate synthetic medical images for educational/research purposes."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if seed is not None:
        torch.manual_seed(int(seed))

    # Janus-Pro's generation head produces 384x384 outputs (576 image tokens); the
    # results are upscaled for display at the end.
    medical_config = {
        'parallel_size': 3,    # number of candidate images per prompt
        'modality': 'MRI',     # could be CT, X-ray, histopathology, ...
        'anatomy': 'brain',    # target anatomy
    }
    messages = [
        {'role': '<|User|>',
         'content': f"{prompt} [Modality: {medical_config['modality']}, Anatomy: {medical_config['anatomy']}]"},
        {'role': '<|Assistant|>', 'content': ''},
    ]
    # Build the text-to-image prompt with the processor's SFT template and append the
    # image start tag so the model begins emitting image tokens.
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages,
        sft_format=vl_chat_processor.sft_format,
        system_prompt='Generate education-quality medical imaging data',
    ) + vl_chat_processor.image_start_tag

    input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
    # CFG-guided image-token sampling; the upstream model has no one-call image method,
    # so this uses the hedged helper defined above.
    generated_tokens, patches = generate_image_tokens(
        input_ids,
        cfg_weight=guidance,
        temperature=t2i_temperature,
        parallel_size=medical_config['parallel_size'],
    )
    # Convert decoder output to uint8 arrays (helper sketch below) and upscale for display.
    synthetic_images = postprocess_medical_images(patches, parallel_size=medical_config['parallel_size'])
    return [Image.fromarray(img).resize((512, 512)) for img in synthetic_images]
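
# A minimal sketch of the post-processing helper used above. Assumption: `patches` is the
# float output of the VQ decoder in roughly [-1, 1], shaped (parallel_size, 3, H, W).
def postprocess_medical_images(patches, parallel_size=3, **_):
    """Map decoder output to uint8 RGB arrays suitable for PIL / Gradio galleries."""
    dec = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
    dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
    return [dec[i] for i in range(parallel_size)]
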
# Medical-optimized Gradio interface
with gr.Blocks(title="Medical Imaging AI Suite") as demo:
    gr.Markdown("""## Medical Image Analysis Suite v2.1
*For research use only - not for clinical diagnosis*""")

    with gr.Tab("Clinical Image Analysis"):
        with gr.Row():
            medical_image_input = gr.Image(label="Upload Medical Scan")
            clinical_question = gr.Textbox(
                label="Clinical Query",
                placeholder="E.g.: 'Assess tumor progression in this MRI series'")
        with gr.Accordion("Advanced Parameters", open=False):
            und_seed = gr.Number(42, label="Reproducibility Seed", precision=0)
            analysis_top_p = gr.Slider(0.8, 1.0, 0.95, label="Diagnostic Certainty (top-p)")
            analysis_temp = gr.Slider(0.1, 0.5, 0.2, label="Analysis Precision (temperature)")
        analysis_btn = gr.Button("Analyze Scan", variant="primary")
        clinical_report = gr.Textbox(label="AI Analysis Report", interactive=False)
        gr.Examples(
            examples=[
                ["Identify pulmonary nodules in this CT scan", "ct_chest.png"],
                ["Assess MRI for multiple sclerosis lesions", "brain_mri.jpg"],
                ["Histopathology analysis: tumor grading", "biopsy_slide.png"]
            ],
            inputs=[clinical_question, medical_image_input]
        )
with gr.Tab("Medical Imaging Synthesis"): | |
gr.Markdown("**Educational Image Generation**") | |
synth_prompt = gr.Textbox(label="Synthesis Prompt", | |
placeholder="E.g.: 'Synthetic brain MRI showing glioblastoma multiforme'") | |
with gr.Row(): | |
synth_guidance = gr.Slider(3, 7, 5, label="Anatomical Accuracy") | |
synth_temp = gr.Slider(0.3, 1.0, 0.6, label="Synthesis Variability") | |
synth_btn = gr.Button("Generate Educational Images", variant="secondary") | |
synthetic_gallery = gr.Gallery(label="Synthetic Medical Images", | |
columns=3, object_fit="contain") | |
gr.Examples( | |
examples=[ | |
"High-resolution CT of healthy lung parenchyma", | |
"T2-weighted MRI of lumbar spine with herniated disc", | |
"Histopathology slide of benign breast tissue" | |
], | |
inputs=synth_prompt | |
) | |
    # Connect functionality (event listeners must be registered inside the Blocks context).
    analysis_btn.click(
        medical_image_analysis,
        inputs=[medical_image_input, clinical_question, und_seed, analysis_top_p, analysis_temp],
        outputs=clinical_report
    )
    # The reproducibility seed from the analysis tab is reused for synthesis.
    synth_btn.click(
        generate_medical_image,
        inputs=[synth_prompt, und_seed, synth_guidance, synth_temp],
        outputs=synthetic_gallery
    )

demo.launch(share=True, server_port=7860)