Spaces:
Runtime error
Runtime error
File size: 6,525 Bytes
39d4c85 2f9ea03 39d4c85 bb16e72 39d4c85 2f9ea03 d58d5be 39d4c85 8e2bfc0 39d4c85 8e2bfc0 39d4c85 8e2bfc0 39d4c85 8e2bfc0 39d4c85 bb16e72 39d4c85 8e2bfc0 39d4c85 bb16e72 39d4c85 8e2bfc0 08137ac 39d4c85 e6713e2 39d4c85 ab9c414 39d4c85 8e2bfc0 39d4c85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
import spaces
# Load the Janus-Pro-1B vision-language model and its chat processor.
# NOTE: this is the general-purpose DeepSeek checkpoint; it is not fine-tuned
# for medical imaging, so outputs must not be treated as clinical findings.
model_path = "deepseek-ai/Janus-Pro-1B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
# Force the eager attention implementation for the language model.
language_config._attn_implementation = 'eager'
# Only arguments the Janus model actually accepts are passed here; unknown
# keyword arguments (e.g. a non-existent "medical_head") are forwarded to the
# model constructor and abort loading.
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    # Half-precision kernels are missing or very slow on many CPU builds of
    # PyTorch, so fall back to full precision off-GPU.
    vl_gpt = vl_gpt.to(torch.float32)
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
@torch.inference_mode()
@spaces.GPU(duration=120)
def medical_image_analysis(medical_image, clinical_question, seed, top_p, temperature):
    """Answer a free-text question about an uploaded image with Janus-Pro.

    Intended for research/education use on CT, MRI, X-ray or histopathology
    images; the returned text is model output, not a diagnosis.

    Args:
        medical_image: The uploaded image as a numpy array (what ``gr.Image``
            delivers with its default ``type``), or ``None`` if nothing was
            uploaded.
        clinical_question: Free-text question / clinical context.
        seed: RNG seed (Gradio ``Number`` delivers a float; cast to ``int``).
        top_p: Nucleus-sampling threshold passed to ``generate``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.

    Returns:
        ``"Clinical Findings:\\n"`` followed by the decoded model output, or a
        short prompt to upload an image when ``medical_image`` is ``None``.
    """
    if medical_image is None:
        return "Please upload an image to analyze."
    torch.cuda.empty_cache()
    torch.manual_seed(int(seed))

    # Use the role names and image tag from the upstream Janus examples;
    # custom roles/tags are not part of Janus's chat format and the image would
    # not be placed into the prompt.
    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\nClinical Context: {clinical_question}",
            "images": [medical_image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    pil_images = [Image.fromarray(medical_image)]

    # Cast to whatever dtype the model was loaded in instead of hard-coding
    # bfloat16, so CPU and GPU paths both match the model weights.
    inputs = vl_chat_processor(
        conversations=conversation,
        images=pil_images,
        force_batchify=True,
    ).to(cuda_device, dtype=vl_gpt.dtype)
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

    # Honour the caller's sampling parameters; sampling must be enabled for
    # temperature/top_p to have any effect.
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=temperature != 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.2,  # discourage repeated phrases
        use_cache=True,
    )
    findings = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return f"Clinical Findings:\n{findings}"
def _sample_image_patches(input_ids, width, height, temperature, parallel_size,
                          cfg_weight, image_token_num_per_image=576, patch_size=16):
    """Sample Janus image tokens with classifier-free guidance and decode them.

    Args:
        input_ids: 1-D tensor of prompt token ids (ending with the image-start tag).
        width, height: Output size in pixels; each must be a multiple of ``patch_size``.
        temperature: Softmax temperature (must be > 0).
        parallel_size: Number of images sampled in parallel.
        cfg_weight: Classifier-free guidance scale.
        image_token_num_per_image: Tokens per image; 576 = (384 // 16) ** 2.
        patch_size: Pixels per image token along each axis.

    Returns:
        The tensor produced by ``vl_gpt.gen_vision_model.decode_code``.
    """
    # Rows alternate conditional (even) / unconditional (odd); the unconditional
    # copy has its prompt body replaced by the pad token.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=cuda_device)
    for row in range(parallel_size * 2):
        tokens[row, :] = input_ids
        if row % 2 != 0:
            tokens[row, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int, device=cuda_device
    )
    past_key_values = None
    for step in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values
        logits = vl_gpt.gen_head(outputs.last_hidden_state[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, step] = next_token.squeeze(dim=-1)
        # Feed the sampled token to both the conditional and unconditional row.
        next_token = torch.cat([next_token, next_token], dim=1).view(-1)
        inputs_embeds = vl_gpt.prepare_gen_img_embeds(next_token).unsqueeze(dim=1)

    return vl_gpt.gen_vision_model.decode_code(
        generated_tokens,
        shape=[parallel_size, 8, width // patch_size, height // patch_size],
    )


def _patches_to_uint8(patches):
    """Convert decoder output of shape (N, C, H, W) to uint8 arrays of shape (N, H, W, C).

    Assumes the decoder emits values in [-1, 1] (TODO confirm); values are
    rescaled to [0, 255] and clipped.
    """
    images = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    return np.clip((images + 1) / 2 * 255, 0, 255).astype(np.uint8)


@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_medical_image(prompt, seed=None, guidance=5, t2i_temperature=0.5,
                           modality=None, anatomy=None):
    """Generate synthetic images for educational/research purposes.

    Args:
        prompt: Text description of the image to synthesize.
        seed: Optional RNG seed (cast to ``int``); ``None`` leaves RNG state as is.
        guidance: Classifier-free guidance weight.
        t2i_temperature: Sampling temperature for image tokens.
        modality: Optional modality hint (e.g. "CT") appended to the prompt.
        anatomy: Optional anatomy hint (e.g. "chest") appended to the prompt.

    Returns:
        A list of three 512x512 PIL images.
    """
    torch.cuda.empty_cache()
    if seed is not None:
        torch.manual_seed(int(seed))

    # Janus-Pro natively generates 384x384 images (24x24 tokens of 16 px).
    width = height = 384
    parallel_size = 3
    # Guard against division by zero inside the softmax temperature scaling.
    temperature = max(float(t2i_temperature), 1e-5)

    # Append modality/anatomy only when given, instead of forcing
    # "mri/brain" onto every prompt.
    hints = [
        f"{label}: {value}"
        for label, value in (("Modality", modality), ("Anatomy", anatomy))
        if value
    ]
    content = f"{prompt} [{', '.join(hints)}]" if hints else prompt

    messages = [
        {'role': '<|User|>', 'content': content},
        {'role': '<|Assistant|>', 'content': ''},
    ]
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages,
        sft_format=vl_chat_processor.sft_format,
        system_prompt='',
    )
    # Ending the prompt with the image-start tag asks the model for image tokens.
    text += vl_chat_processor.image_start_tag
    input_ids = torch.LongTensor(tokenizer.encode(text))

    patches = _sample_image_patches(
        input_ids,
        width,
        height,
        temperature=temperature,
        parallel_size=parallel_size,
        cfg_weight=guidance,
    )
    images = _patches_to_uint8(patches)
    return [Image.fromarray(img).resize((512, 512)) for img in images]
# Gradio interface: an image-analysis tab and an image-synthesis tab.
with gr.Blocks(title="Medical Imaging AI Suite") as demo:
    gr.Markdown("""## Medical Image Analysis Suite v2.1
*For research use only - not for clinical diagnosis*""")

    with gr.Tab("Clinical Image Analysis"):
        with gr.Row():
            medical_image_input = gr.Image(label="Upload Medical Scan")
            clinical_question = gr.Textbox(
                label="Clinical Query",
                placeholder="E.g.: 'Assess tumor progression in this MRI series'",
            )
        with gr.Accordion("Advanced Parameters", open=False):
            und_seed = gr.Number(42, label="Reproducibility Seed")
            # Labels describe what the controls actually are (sampling
            # parameters), not a diagnostic confidence or precision.
            analysis_top_p = gr.Slider(0.8, 1.0, 0.95, label="Top-p (nucleus sampling)")
            analysis_temp = gr.Slider(0.1, 0.5, 0.2, label="Temperature")
        analysis_btn = gr.Button("Analyze Scan", variant="primary")
        clinical_report = gr.Textbox(label="AI Analysis Report", interactive=False)

        # Only list examples whose sample images are present next to the app;
        # NOTE(review): a referenced-but-missing file is presumably what would
        # break example loading — confirm the files ship with the Space.
        analysis_examples = [
            [question, image_path]
            for question, image_path in (
                ("Identify pulmonary nodules in this CT scan", "ct_chest.png"),
                ("Assess MRI for multiple sclerosis lesions", "brain_mri.jpg"),
                ("Histopathology analysis: tumor grading", "biopsy_slide.png"),
            )
            if os.path.exists(image_path)
        ]
        if analysis_examples:
            gr.Examples(
                examples=analysis_examples,
                inputs=[clinical_question, medical_image_input],
            )

    with gr.Tab("Medical Imaging Synthesis"):
        gr.Markdown("**Educational Image Generation**")
        synth_prompt = gr.Textbox(
            label="Synthesis Prompt",
            placeholder="E.g.: 'Synthetic brain MRI showing glioblastoma multiforme'",
        )
        with gr.Row():
            synth_guidance = gr.Slider(3, 7, 5, label="Guidance (CFG weight)")
            synth_temp = gr.Slider(0.3, 1.0, 0.6, label="Synthesis Variability")
            # Dedicated seed so synthesis does not depend on a hidden control
            # in the analysis tab.
            synth_seed = gr.Number(42, label="Reproducibility Seed")
        synth_btn = gr.Button("Generate Educational Images", variant="secondary")
        synthetic_gallery = gr.Gallery(
            label="Synthetic Medical Images",
            columns=3,
            object_fit="contain",
        )
        gr.Examples(
            examples=[
                "High-resolution CT of healthy lung parenchyma",
                "T2-weighted MRI of lumbar spine with herniated disc",
                "Histopathology slide of benign breast tissue",
            ],
            inputs=synth_prompt,
        )

    # Wire the buttons to their handlers.
    analysis_btn.click(
        medical_image_analysis,
        inputs=[medical_image_input, clinical_question, und_seed, analysis_top_p, analysis_temp],
        outputs=clinical_report,
    )
    synth_btn.click(
        generate_medical_image,
        inputs=[synth_prompt, synth_seed, synth_guidance, synth_temp],
        outputs=synthetic_gallery,
    )

demo.launch(share=True, server_port=7860)