mgbam committed · verified
Commit e6713e2 · 1 Parent(s): 05acc1a

Update app.py

Files changed (1):
  1. app.py (+73 -90)
app.py CHANGED

@@ -1,45 +1,37 @@
-import gradio as gr
 import torch
 from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
 from PIL import Image
 from diffusers.models import AutoencoderKL
 import numpy as np
-import spaces # Import spaces for ZeroGPU compatibility
+import gradio as gr # Import gradio for UI
 
+# CUDA availability check
 cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print(f"Using device: {cuda_device}")
 
-# Load model and processor
-model_path = "deepseek-ai/JanusFlow-1.3B"
+# Load model and processor (adjust path if needed)
+model_path = "deepseek-ai/JanusFlow-1.3B" # You may need to change to your local path
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
 vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
 vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
 
-# remember to use bfloat16 dtype, this vae doesn't work with fp16
-vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
+# Load VAE for image generation
+vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae") # You may need to change to your local path
 vae = vae.to(torch.bfloat16).to(cuda_device).eval()
 
-# Multimodal Understanding function
+# Multimodal Understanding function (modified for medical context)
 @torch.inference_mode()
-@spaces.GPU(duration=120)
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
+    # Clear CUDA cache before generating to prevent memory leaks
     torch.cuda.empty_cache()
-
-    # set seed
+
+    # Set seed for reproducibility
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed(seed)
-
-    # Medical image preprocessing (this is a placeholder, implement based on your specific needs)
-    # NOTE: If input is DICOM or another medical format, add custom loading and preprocessing steps here
-    # Example: if input is DICOM:
-    # 1. load with pydicom.dcmread()
-    # 2. normalize pixel values based on windowing/leveling if necessary
-    # 3. convert to np.array
-    # else: if the input is a regular numpy array (e.g. png or jpg) no action is needed, image = image
-
+
     conversation = [
         {
             "role": "User",
@@ -48,15 +40,14 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
         },
         {"role": "Assistant", "content": ""},
     ]
-
+
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
-
+
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
+
     outputs = vl_gpt.language_model.generate(
         inputs_embeds=inputs_embeds,
         attention_mask=prepare_inputs.attention_mask,
@@ -69,14 +60,13 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
         temperature=temperature,
         top_p=top_p,
     )
-
+
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
 
     return answer
 
-
+# Image Generation Function (modified for medical context)
 @torch.inference_mode()
-@spaces.GPU(duration=120)
 def generate(
     input_ids,
     cfg_weight: float = 2.0,
@@ -158,8 +148,8 @@ def unpack(dec, width, height, parallel_size=5):
     return visual_img
 
 
+# Main image generation function
 @torch.inference_mode()
-@spaces.GPU(duration=120)
 def generate_image(prompt,
                    seed=None,
                    guidance=5,
@@ -185,80 +175,73 @@ def generate_image(prompt,
         num_inference_steps=num_inference_steps)
     return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
 
-
 
 # Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown(value="# Medical Image Analysis and Generation")
-    # with gr.Row():
-    with gr.Row():
-        image_input = gr.Image(label="Medical Image Input")
-        with gr.Column():
-            question_input = gr.Textbox(label="Analysis Prompt (e.g., 'Identify tumor', 'Characterize lesion', 'Describe anatomic structures')")
-            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-    understanding_button = gr.Button("Analyze Image")
-    understanding_output = gr.Textbox(label="Analysis Response")
-
-    examples_inpainting = gr.Examples(
-        label="Multimodal Understanding examples",
-        examples=[
-            [
-                "Identify the tumor in the given image.",
-                "./ct_scan.png" # Placeholder medical image path
-            ],
-            [
-                "Characterize the lesion in the image. Is it malignant or benign?",
-                "./mri_scan.png", # Placeholder medical image path
+with gr.Blocks(title="JanusFlow Medical Image Assistant") as demo:
+    gr.Markdown(value="# Medical Image Understanding and Generation")
+
+    with gr.Tab("Multimodal Understanding"):
+        with gr.Row():
+            image_input = gr.Image(label="Medical Image Input")
+            with gr.Column():
+                question_input = gr.Textbox(label="Medical Question")
+                und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top P")
+                temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
+
+        understanding_button = gr.Button("Analyze Image")
+        understanding_output = gr.Textbox(label="Analysis Response")
+
+        examples_understanding = gr.Examples(
+            label="Examples: Image Analysis",
+            examples=[
+                [
+                    "What are the visible structures in this ultrasound?",
+                    "./ultrasound.jpeg"
+                ],
+                [
+                    "Identify abnormalities in the image.",
+                    "./cardiac_ultrasound.jpeg"
+                ],
+                [
+                    "Describe the features and histological analysis in this image.",
+                    "./histology.jpeg"
+                ],
            ],
-            [
-                "Generate a report for the given medical image.",
-                "./xray.png", # Placeholder medical image path
+            inputs=[question_input, image_input],
+        )
+
+    with gr.Tab("Text-to-Image Generation"):
+        with gr.Row():
+            cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
+            step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference Steps")
+
+        prompt_input = gr.Textbox(label="Medical Image Generation Prompt")
+        seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
+        generation_button = gr.Button("Generate Medical Image")
+        image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
+
+        examples_t2i = gr.Examples(
+            label="Examples: Image Generation",
+            examples=[
+                "Generate a coronal view of a brain MRI with a tumor.",
+                "Create an X-ray image showing a fractured femur.",
+                "Create an image of Histology of Liver Cirrhosis.",
            ],
-
-        ],
-        inputs=[question_input, image_input],
-    )
-
-
-    gr.Markdown(value="# Medical Image Generation with Hugging Face Logo")
-
-
-
-    with gr.Row():
-        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
-        step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")
-
-    prompt_input = gr.Textbox(label="Generation Prompt (e.g., 'Generate a CT scan with the Hugging Face logo', 'Create an MRI scan showing the Hugging Face logo', 'Render a medical x-ray with the Hugging Face logo.')")
-    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
-
-    generation_button = gr.Button("Generate Images")
-
-    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
-
-    examples_t2i = gr.Examples(
-        label="Medical image generation examples with Hugging Face logo.",
-        examples=[
-            "Generate a CT scan with the Hugging Face logo clearly visible.",
-            "Create an MRI scan showing the Hugging Face logo embedded within the tissue.",
-            "Render a medical x-ray with the Hugging Face logo subtly visible in the background.",
-            "Generate an ultrasound image with a faint Hugging Face logo on the screen",
-        ],
-        inputs=prompt_input,
-    )
+            inputs=prompt_input,
+        )
+
 
     understanding_button.click(
         multimodal_understanding,
         inputs=[image_input, question_input, und_seed_input, top_p, temperature],
         outputs=understanding_output
    )
-
+
    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
        outputs=image_output
    )
-
-demo.launch(share=True, ssr_mode = False)
+
+demo.launch(share=True)
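
Side note on the DICOM placeholder comment that this commit removes from multimodal_understanding(): the steps it listed (load with pydicom.dcmread(), normalize via windowing/leveling, convert to np.array) could look roughly like the sketch below. This is only an illustration of those steps, not code from the repository; pydicom is an assumed extra dependency and load_dicom_as_array is a hypothetical helper name.

import numpy as np
import pydicom  # assumed extra dependency, not declared by this Space

def load_dicom_as_array(path: str) -> np.ndarray:
    """Hypothetical helper mirroring the removed comment's steps."""
    ds = pydicom.dcmread(path)                  # 1. load with pydicom.dcmread()
    pixels = ds.pixel_array.astype(np.float32)

    # 2. normalize pixel values (simple min-max scaling stands in for
    #    proper windowing/leveling based on the DICOM window metadata)
    pixels -= pixels.min()
    if pixels.max() > 0:
        pixels /= pixels.max()

    # 3. convert to a uint8 RGB np.array so Image.fromarray() accepts it
    pixels = (pixels * 255).astype(np.uint8)
    if pixels.ndim == 2:
        pixels = np.stack([pixels] * 3, axis=-1)
    return pixels

Such a helper would run before Image.fromarray(image) when the input is a DICOM file rather than a regular PNG/JPG array, which needs no extra handling.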