File size: 3,627 Bytes
2f9ea03
bb16e72
2f9ea03
bb16e72
d58d5be
14f626b
8e2bfc0
ed1226e
bb16e72
ac1cd8a
 
8e2bfc0
bb16e72
 
 
ac1cd8a
 
 
 
 
bb16e72
ac1cd8a
bb16e72
 
ac1cd8a
 
 
bb16e72
 
ac1cd8a
bb16e72
 
ac1cd8a
bb16e72
 
 
 
ab9c414
bb16e72
8e2bfc0
bb16e72
8e2bfc0
ac1cd8a
ab9c414
8e2bfc0
ac1cd8a
bb16e72
 
 
ac1cd8a
8e2bfc0
 
 
ac1cd8a
bb16e72
ab9c414
8e2bfc0
ac1cd8a
 
 
8e2bfc0
bb16e72
ac1cd8a
bb16e72
 
8e2bfc0
bb16e72
ab9c414
 
ac1cd8a
 
8e2bfc0
bb16e72
ac1cd8a
 
 
8e2bfc0
ab9c414
08137ac
bb16e72
ab9c414
 
 
08137ac
ab9c414
e6713e2
ab9c414
ac1cd8a
 
 
 
ab9c414
 
e6713e2
ac1cd8a
bb16e72
 
 
ab9c414
 
 
 
 
 
8e2bfc0
 
ac1cd8a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers import AutoencoderKL
import numpy as np
import gradio as gr

# Prefer the GPU whenever PyTorch can see one; bfloat16 there halves weight
# memory relative to float32, while the CPU path keeps full float32 precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.bfloat16
else:
    device, torch_dtype = "cpu", torch.float32
print(f"Using device: {device}")

# Initialize medical imaging components
def load_medical_models(
    model_path="deepseek-ai/Janus-1.3B",
    vae_path="stabilityai/sdxl-vae",
):
    """Load the Janus processor, the Janus model and the SDXL VAE onto ``device``.

    Args:
        model_path: Hugging Face repo id or local directory of the Janus
            checkpoint (used for both the processor and the model).
        vae_path: Repo id or local directory of the ``AutoencoderKL`` weights.

    Returns:
        A ``(processor, model, vae)`` tuple. Both networks are moved to
        ``device``, loaded in ``torch_dtype`` and switched to eval mode.

    Raises:
        Exception: Whatever the underlying ``from_pretrained`` / ``.to`` calls
            raise (missing weights, network errors, out-of-memory). The error
            is printed and re-raised so the app fails fast instead of starting
            half-initialised.
    """
    try:
        # The stock processor is used as-is. Janus's VLChatProcessor has no
        # "medical_mode" option; the former kwarg was ignored at best. Any
        # domain framing belongs in the prompt, not the processor.
        processor = VLChatProcessor.from_pretrained(model_path)

        model = MultiModalityCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            attn_implementation="eager",  # force the plain (non-fused) attention path
            low_cpu_mem_usage=True,
        ).to(device).eval()

        # NOTE(review): medical_analysis in this file never uses the VAE. It is
        # still loaded because callers unpack three values. Consider dropping it
        # to save download time and device memory.
        vae = AutoencoderKL.from_pretrained(
            vae_path,
            torch_dtype=torch_dtype,
        ).to(device).eval()

        return processor, model, vae
    except Exception as e:
        print(f"Error loading medical models: {str(e)}")
        raise


# Load once at import time so every request reuses the same weights.
processor, model, vae = load_medical_models()

# Medical image analysis function
def medical_analysis(image, question, seed=42):
    """Answer a free-text clinical question about a single image with Janus.

    Args:
        image: The uploaded image, as a ``PIL.Image`` or a NumPy array.
            ``None`` (nothing uploaded) returns a prompt to upload an image
            instead of calling the model.
        question: The user's question about the image. A blank question
            returns a prompt to enter one.
        seed: Seed applied to torch and NumPy so sampled output is repeatable.

    Returns:
        The decoded model answer. If input is missing or inference fails, a
        human-readable message is returned instead. Errors are returned, not
        raised, because this function is wired directly to a Gradio output box.
    """
    if image is None:
        return "Please upload an image to analyze."
    if not question or not question.strip():
        return "Please enter a clinical question."

    try:
        # Set random seeds for reproducibility (generation below samples).
        torch.manual_seed(seed)
        np.random.seed(seed)

        rgb_image = _to_rgb_image(image)

        # Janus expects a chat transcript. "<image_placeholder>" marks where
        # the image embeddings are spliced into the prompt, and the empty
        # Assistant turn asks the model to write the reply.
        conversation = [
            {
                "role": "User",
                "content": f"<image_placeholder>\n{question.strip()}",
                "images": [rgb_image],
            },
            {"role": "Assistant", "content": ""},
        ]
        # Pass the model dtype explicitly. The output's .to() otherwise casts
        # pixel_values to bfloat16, which mismatches the float32 CPU model.
        inputs = processor(
            conversations=conversation,
            images=[rgb_image],
            force_batchify=True,
        ).to(device, dtype=torch_dtype)

        tokenizer = processor.tokenizer
        with torch.no_grad():
            # Run the vision encoder and merge its output into the text
            # embeddings. Generating from input_ids alone would ignore the image.
            inputs_embeds = model.prepare_inputs_embeds(**inputs)
            outputs = model.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.1,
                top_p=0.95,
                do_sample=True,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return report.replace("##MEDICAL_REPORT##", "").strip()
    except Exception as e:
        return f"Radiology analysis error: {str(e)}"


def _to_rgb_image(image):
    """Return *image* as a three-channel RGB ``PIL.Image``.

    Accepts a PIL image or a NumPy array. Other PIL modes (e.g. grayscale "L")
    are converted so the vision encoder always receives RGB input.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    return image.convert("RGB")

# Medical interface
with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""# AI Radiology Assistant
                **CT/MRI/X-ray Analysis System**""")

    with gr.Tab("Diagnostic Imaging"):
        with gr.Row():
            # gr.Image decodes uploads through PIL, which cannot read DICOM
            # (.dcm) files. Studies must be exported to a raster format first,
            # so the label says that instead of promising DICOM support.
            med_image = gr.Image(label="Medical Image (PNG/JPEG export)", type="pil")
            med_question = gr.Textbox(
                label="Clinical Query",
                placeholder="Describe findings in this CT scan...",
            )
        analysis_btn = gr.Button("Analyze", variant="primary")
        report_output = gr.Textbox(label="Radiology Report", interactive=False)

    # Pressing Enter in the query box and clicking the button run the same
    # analysis into the same output box.
    for trigger in (med_question.submit, analysis_btn.click):
        trigger(
            medical_analysis,
            inputs=[med_image, med_question],
            outputs=report_output,
        )

# Launch with CPU optimization
# Enable request queuing through Blocks.queue(). The launch(enable_queue=...)
# argument was deprecated in Gradio 3.x and removed in Gradio 4, where it
# raises TypeError. queue() is the supported API in both major versions.
demo.queue()
demo.launch(
    server_name="0.0.0.0",  # listen on all interfaces (container / hosted Space)
    server_port=7860,
    max_threads=2,  # cap the size of Gradio's worker thread pool
)