File size: 10,811 Bytes
d58d5be
2f9ea03
08137ac
2f9ea03
08137ac
d58d5be
08137ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d58d5be
08137ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d58d5be
08137ac
 
 
d58d5be
08137ac
 
 
 
 
d58d5be
08137ac
 
d58d5be
08137ac
2f9ea03
08137ac
2f9ea03
08137ac
 
 
 
 
 
 
 
 
d58d5be
08137ac
 
 
 
 
d58d5be
08137ac
 
 
 
 
2f9ea03
 
08137ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import gradio as gr
import torch
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers.models import AutoencoderKL
import numpy as np
import spaces  # Import spaces for ZeroGPU compatibility

# Target device for every model and input tensor in this file:
# CUDA when torch reports it as available, otherwise the CPU.
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One checkpoint id is used for both the chat processor and the multimodal model.
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Multimodal LLM: load, cast to bfloat16, move to the target device, set eval mode.
vl_gpt = (
    MultiModalityCausalLM.from_pretrained(model_path)
    .to(torch.bfloat16)
    .to(cuda_device)
    .eval()
)

# SDXL VAE used to decode generated latents into pixels.
# Keep it in bfloat16: this VAE does not work with fp16.
vae = (
    AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
    .to(torch.bfloat16)
    .to(cuda_device)
    .eval()
)

# Multimodal Understanding function
@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image, question, seed, top_p, temperature):
    """Answer a text question about a single image.

    Args:
        image: Image data accepted by ``PIL.Image.fromarray`` (presumably the
            numpy array delivered by the ``gr.Image`` input; confirm against
            the UI wiring).
        question: Prompt text placed after the image placeholder.
        seed: Seed applied to the torch (CPU and CUDA) and numpy RNGs.
        top_p: Nucleus-sampling threshold; only used when sampling.
        temperature: ``0`` selects greedy decoding; any other value enables
            sampling at that temperature.

    Returns:
        The decoded model output as a string, with special tokens skipped.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Fail with a readable message instead of an obscure error from
    # Image.fromarray(None) when the button is clicked without an image.
    if image is None:
        raise gr.Error("Please provide an image to analyze.")

    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    # Medical image preprocessing is a placeholder. If the input is DICOM or
    # another medical format, add loading/normalisation here (e.g. read with
    # pydicom.dcmread(), apply windowing/leveling, convert to np.array).
    # A regular numpy image (png/jpg upload) needs no extra handling.

    conversation = [
        {
            "role": "User",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = [Image.fromarray(image)]
    # The model weights are always bfloat16 (also on CPU), so the inputs must
    # use the same dtype regardless of the device.
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    # Sampling parameters only apply when sampling is enabled, so they are
    # only forwarded in that case.
    do_sample = temperature != 0
    sampling_kwargs = (
        {"temperature": temperature, "top_p": top_p} if do_sample else {}
    )

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=do_sample,
        use_cache=True,
        **sampling_kwargs,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

    return answer


@torch.inference_mode()
@spaces.GPU(duration=120)
def generate(
    input_ids,
    cfg_weight: float = 2.0,
    num_inference_steps: int = 30,
    parallel_size: int = 5,
):
    """Generate a batch of images from prompt token ids.

    Integrates the model's predicted velocity with fixed-size Euler steps
    under classifier-free guidance (CFG), then decodes the final latents with
    the SDXL VAE.

    Args:
        input_ids: 1-D tensor of prompt token ids; its last token is dropped
            and replaced by the time embedding at every step.
        cfg_weight: Guidance weight ``w`` in ``w * v_cond - (w - 1) * v_uncond``.
        num_inference_steps: Number of Euler steps; coerced to ``int``.
        parallel_size: Number of images generated per call. The model batch
            is twice this size (conditional + unconditional halves).

    Returns:
        ``np.uint8`` array of shape ``(parallel_size, H, W, 3)`` with values
        in ``[0, 255]``; ``H`` and ``W`` are whatever the VAE decoder emits.
    """
    # UI sliders may hand over a float; range() below needs an int.
    num_inference_steps = int(num_inference_steps)
    # Keep every tensor on the device the model actually lives on.
    device = vl_gpt.device
    batch_size = parallel_size * 2  # *2 for CFG

    # First half keeps the prompt; second half is padded out (except the
    # first token) to form the unconditional branch.
    tokens = torch.stack([input_ids] * batch_size).to(device)
    tokens[parallel_size:, 1:] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)

    # we remove the last <bog> token and replace it with t_emb later
    inputs_embeds = inputs_embeds[:, :-1, :]
    prompt_len = inputs_embeds.shape[1]

    # generate with rectified flow ode
    # step 1: draw the initial latents (sampled on CPU, then moved, so a
    # given CPU seed yields the same noise on any device)
    z = torch.randn((parallel_size, 4, 48, 48), dtype=torch.bfloat16).to(device)

    # Step size as a tensor shaped like z.
    dt = torch.zeros_like(z) + 1.0 / num_inference_steps

    # step 2: run ode
    # The LLM sequence is prompt + 1 time-embedding token + the latent tokens
    # (the last 576 positions are read back as a 24x24 grid below).
    num_latent_tokens = 576
    attention_mask = torch.ones(
        (batch_size, prompt_len + 1 + num_latent_tokens)
    ).to(device)
    # we apply attention mask for CFG: 1 for tokens that are not masked,
    # 0 for tokens that are masked.
    attention_mask[parallel_size:, 1:prompt_len] = 0
    attention_mask = attention_mask.int()

    for step in range(num_inference_steps):
        # prepare inputs for the llm
        z_input = torch.cat([z, z], dim=0)  # for cfg
        t = step / num_inference_steps * 1000.
        t = torch.tensor([t] * z_input.shape[0]).to(dt)
        z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
        z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
        z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
        z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
        llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)

        # input to the llm
        # Every step feeds the full sequence (prompt + t_emb + latents), so
        # no KV cache is requested or carried between steps.
        outputs = vl_gpt.language_model.model(
            inputs_embeds=llm_emb,
            use_cache=False,
            attention_mask=attention_mask,
            past_key_values=None,
        )
        hidden_states = outputs.last_hidden_state

        # transform hidden_states back to v
        hidden_states = vl_gpt.vision_gen_dec_aligner(
            vl_gpt.vision_gen_dec_aligner_norm(
                hidden_states[:, -num_latent_tokens:, :]
            )
        )
        hidden_states = hidden_states.reshape(
            z_emb.shape[0], 24, 24, 768
        ).permute(0, 3, 1, 2)
        v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
        v_cond, v_uncond = torch.chunk(v, 2)
        v = cfg_weight * v_cond - (cfg_weight - 1.) * v_uncond
        z = z + dt * v

    # step 3: decode with vision_gen_dec and sdxl vae
    decoded_image = vae.decode(z / vae.config.scaling_factor).sample

    # Map [-1, 1] to uint8 [0, 255] in channels-last layout.
    images = decoded_image.float().clip_(-1., 1.).permute(0, 2, 3, 1).cpu().numpy()
    images = ((images + 1) / 2. * 255).astype(np.uint8)

    return images
    
def unpack(dec, width, height, parallel_size=5):
    """Convert a batch of decoded image tensors to a uint8 numpy array.

    ``dec`` is cast to float32, moved to the CPU, its axes reordered with
    ``(0, 2, 3, 1)``, and its values mapped from ``[-1, 1]`` to ``[0, 255]``
    (out-of-range values are clipped). The result is written into a freshly
    allocated ``(parallel_size, width, height, 3)`` uint8 buffer.

    NOTE(review): the buffer is allocated as (width, height) while the data
    axes come straight from the transpose; for non-square inputs confirm the
    argument order expected by callers.
    """
    pixels = dec.to(torch.float32).cpu().numpy()
    pixels = np.transpose(pixels, (0, 2, 3, 1))
    pixels = np.clip((pixels + 1) / 2 * 255, 0, 255)

    # Assigning into the uint8 buffer performs the float -> uint8 conversion.
    canvas = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    canvas[...] = pixels
    return canvas


@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt,
                   seed=None,
                   guidance=5,
                   num_inference_steps=30):
    """Turn a text prompt into a list of 1024x1024 PIL images.

    Args:
        prompt: User text describing the image to generate.
        seed: When not ``None``, seeds the torch (CPU and CUDA) and numpy RNGs.
        guidance: Forwarded to ``generate`` as ``cfg_weight``.
        num_inference_steps: Forwarded to ``generate`` unchanged.

    Returns:
        One PIL image per array returned by ``generate``, each resized to
        1024x1024 with Lanczos resampling.
    """
    # Free cached CUDA memory before starting a new generation.
    torch.cuda.empty_cache()

    # Seed every RNG only when the caller asked for reproducible output.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

    with torch.no_grad():
        # Single-turn conversation with an empty assistant reply to complete.
        conversation = [
            {'role': 'User', 'content': prompt},
            {'role': 'Assistant', 'content': ''},
        ]
        sft_prompt = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='',
        )
        # The image start tag is appended so generation continues from it.
        sft_prompt = sft_prompt + vl_chat_processor.image_start_tag
        token_ids = torch.LongTensor(tokenizer.encode(sft_prompt))

        samples = generate(
            token_ids,
            cfg_weight=guidance,
            num_inference_steps=num_inference_steps,
        )
        return [
            Image.fromarray(sample).resize((1024, 1024), Image.LANCZOS)
            for sample in samples
        ]

        

# Gradio interface
# Two stacked panels in one Blocks app:
#   1. image understanding, wired to multimodal_understanding
#   2. text-to-image generation, wired to generate_image
# Components render in the order they are created, so statement order here
# is the page layout.
with gr.Blocks() as demo:
    gr.Markdown(value="# Medical Image Analysis and Generation")
    # --- Understanding panel: image on the left, prompt and sampling controls on the right ---
    with gr.Row():
        image_input = gr.Image(label="Medical Image Input")
        with gr.Column():
            question_input = gr.Textbox(label="Analysis Prompt (e.g., 'Identify tumor', 'Characterize lesion', 'Describe anatomic structures')")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            # temperature 0 makes multimodal_understanding use greedy decoding.
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
        
    understanding_button = gr.Button("Analyze Image")
    understanding_output = gr.Textbox(label="Analysis Response")

    # Each example row is [question, image path], matching the order of `inputs`.
    # NOTE(review): the image paths are placeholders; confirm these files
    # exist next to this script before deploying.
    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
         examples=[
            [
              "Identify the tumor in the given image.",
              "./ct_scan.png"  # Placeholder medical image path
            ],
             [
                 "Characterize the lesion in the image. Is it malignant or benign?",
                 "./mri_scan.png",  # Placeholder medical image path
            ],
            [
                 "Generate a report for the given medical image.",
                 "./xray.png",  # Placeholder medical image path
            ],
           
         ],
        inputs=[question_input, image_input],
    )
    
        
    # --- Generation panel ---
    gr.Markdown(value="# Medical Image Generation with Hugging Face Logo")

    
    
    with gr.Row():
        # Passed to generate_image as `guidance` and `num_inference_steps`.
        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
        step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")

    prompt_input = gr.Textbox(label="Generation Prompt (e.g., 'Generate a CT scan with the Hugging Face logo', 'Create an MRI scan showing the Hugging Face logo', 'Render a medical x-ray with the Hugging Face logo.')")
    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)

    generation_button = gr.Button("Generate Images")

    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)

    # Prompt-only examples; each string fills `prompt_input`.
    examples_t2i = gr.Examples(
        label="Medical image generation examples with Hugging Face logo.",
        examples=[
            "Generate a CT scan with the Hugging Face logo clearly visible.",
            "Create an MRI scan showing the Hugging Face logo embedded within the tissue.",
            "Render a medical x-ray with the Hugging Face logo subtly visible in the background.",
            "Generate an ultrasound image with a faint Hugging Face logo on the screen",
        ],
        inputs=prompt_input,
    )
    
    # Input order must match the positional parameters of
    # multimodal_understanding(image, question, seed, top_p, temperature).
    understanding_button.click(
        multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
        outputs=understanding_output
    )
    
    # Input order must match the positional parameters of
    # generate_image(prompt, seed, guidance, num_inference_steps).
    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
        outputs=image_output
    )

# Start the app at import time (there is no __main__ guard). A public share
# link is requested and ssr_mode is explicitly turned off.
demo.launch(share=True, ssr_mode = False)