File size: 2,859 Bytes
2f9ea03
bb16e72
2f9ea03
bb16e72
d58d5be
14f626b
8e2bfc0
bb16e72
 
 
8e2bfc0
bb16e72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e2bfc0
bb16e72
8e2bfc0
 
bb16e72
8e2bfc0
bb16e72
 
 
 
 
8e2bfc0
 
 
bb16e72
 
8e2bfc0
bb16e72
8e2bfc0
bb16e72
 
 
 
8e2bfc0
bb16e72
8e2bfc0
bb16e72
8e2bfc0
bb16e72
 
8e2bfc0
bb16e72
08137ac
bb16e72
 
 
08137ac
bb16e72
e6713e2
bb16e72
 
 
 
 
 
 
 
 
e6713e2
 
bb16e72
 
 
 
8e2bfc0
 
bb16e72
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers import AutoencoderKL
import numpy as np
import gradio as gr

# Configure compute device: prefer GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Initialize medical imaging components
def load_medical_models():
    """Load the Janus processor, the multimodal model, and the SDXL VAE.

    All heavy components are moved to `device` and put in eval mode.

    Returns:
        Tuple of (processor, model, vae).

    Raises:
        Re-raises any exception from the underlying `from_pretrained` calls
        after printing a diagnostic message.
    """
    # bf16 on GPU to halve memory; full fp32 precision on CPU.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    try:
        # Processor / tokenizer for the Janus multimodal model.
        processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")

        # Base multimodal causal LM.
        model = (
            MultiModalityCausalLM
            .from_pretrained("deepseek-ai/Janus-1.3B", torch_dtype=dtype)
            .to(device)
            .eval()
        )

        # VAE used for image processing.
        vae = (
            AutoencoderKL
            .from_pretrained("stabilityai/sdxl-vae", torch_dtype=dtype)
            .to(device)
            .eval()
        )
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
    return processor, model, vae

processor, model, vae = load_medical_models()

# Medical image analysis function
def medical_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
    """Run the multimodal model on a medical image with a clinical query.

    Args:
        image: Input image — PIL.Image or numpy array (Gradio may supply
            either), or None if the user submitted without an image.
        question: Free-text clinical question about the image.
        seed: RNG seed for reproducible sampling.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature (low value => near-deterministic).

    Returns:
        The decoded model response string, or an "Analysis error: ..."
        message on failure (the UI displays it rather than crashing).
    """
    try:
        # Guard: Gradio passes None when no image was uploaded.
        if image is None:
            return "Analysis error: no image provided"

        # Seed torch (CPU and CUDA) and numpy for reproducibility.
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Normalize numpy input to an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")

        inputs = processor(
            text=question,
            images=[image],
            return_tensors="pt"
        ).to(device)

        # do_sample=True is required for temperature/top_p to take effect;
        # without it, generate() decodes greedily and silently ignores both
        # parameters (the seeding above makes sampling reproducible).
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            do_sample=True,
            temperature=temperature,
            top_p=top_p
        )

        return processor.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Best-effort: surface the failure as text in the UI.
        return f"Analysis error: {str(e)}"

# Medical interface: single-tab Gradio app wiring the analysis function
# to an image upload plus a free-text clinical query.
with gr.Blocks(title="Medical Imaging Assistant") as demo:
    gr.Markdown("# Medical Imaging AI Assistant")

    with gr.Tab("Analysis"):
        # Image and query side by side; findings box below them.
        with gr.Row():
            image_input = gr.Image(label="Input Image", type="pil")
            query_input = gr.Textbox(label="Clinical Query")
        findings_box = gr.Textbox(label="Findings")

        # Clickable sample prompts (image paths resolved relative to CWD).
        gr.Examples(
            examples=[
                ["ultrasound_sample.jpg", "Identify any abnormalities in this ultrasound"],
                ["xray_sample.jpg", "Describe the bone structure visible in this X-ray"],
            ],
            inputs=[image_input, query_input],
        )

    # Pressing Enter in the query box runs the analysis.
    query_input.submit(
        medical_analysis,
        inputs=[image_input, query_input],
        outputs=findings_box,
    )

demo.launch()