Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -1,45 +1,37 @@
-import gradio as gr
 import torch
 from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
 from PIL import Image
 from diffusers.models import AutoencoderKL
 import numpy as np
-import
+import gradio as gr  # Import gradio for UI

+# CUDA availability check
 cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print(f"Using device: {cuda_device}")

-# Load model and processor
-model_path = "deepseek-ai/JanusFlow-1.3B"
+# Load model and processor (adjust path if needed)
+model_path = "deepseek-ai/JanusFlow-1.3B"  # You may need to change to your local path
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer

 vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
 vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()

-#
-vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
+# Load VAE for image generation
+vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")  # You may need to change to your local path
 vae = vae.to(torch.bfloat16).to(cuda_device).eval()

-# Multimodal Understanding function
+# Multimodal Understanding function (modified for medical context)
 @torch.inference_mode()
-@spaces.GPU(duration=120)
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
+    # Clear CUDA cache before generating to prevent memory leaks
     torch.cuda.empty_cache()

-    #
+    # Set seed for reproducibility
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed(seed)

-    # Medical image preprocessing (this is a placeholder, implement based on your specific needs)
-    # NOTE: If input is DICOM or another medical format, add custom loading and preprocessing steps here
-    # Example: if input is DICOM:
-    # 1. load with pydicom.dcmread()
-    # 2. normalize pixel values based on windowing/leveling if necessary
-    # 3. convert to np.array
-    # else: if the input is a regular numpy array (e.g. png or jpg) no action is needed, image = image
-
     conversation = [
         {
             "role": "User",
@@ -48,15 +40,14 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
         },
         {"role": "Assistant", "content": ""},
     ]

     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-

     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

     outputs = vl_gpt.language_model.generate(
         inputs_embeds=inputs_embeds,
         attention_mask=prepare_inputs.attention_mask,
@@ -69,14 +60,13 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
         temperature=temperature,
         top_p=top_p,
     )

     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

     return answer

-
+# Image Generation Function (modified for medical context)
 @torch.inference_mode()
-@spaces.GPU(duration=120)
 def generate(
     input_ids,
     cfg_weight: float = 2.0,
@@ -158,8 +148,8 @@ def unpack(dec, width, height, parallel_size=5):
     return visual_img


+# Main image generation function
 @torch.inference_mode()
-@spaces.GPU(duration=120)
 def generate_image(prompt,
                    seed=None,
                    guidance=5,
@@ -185,80 +175,73 @@ def generate_image(prompt,
                    num_inference_steps=num_inference_steps)
     return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]

-

 # Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown(value="# Medical Image
-
-    with gr.
-    [… several old interface lines are truncated in the diff view …]
-    gr.Markdown(value="# Medical Image Generation with Hugging Face Logo")
-
-    with gr.Row():
-        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
-        step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")
-
-    prompt_input = gr.Textbox(label="Generation Prompt (e.g., 'Generate a CT scan with the Hugging Face logo', 'Create an MRI scan showing the Hugging Face logo', 'Render a medical x-ray with the Hugging Face logo.')")
-    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
-
-    generation_button = gr.Button("Generate Images")
-
-    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
-
-    examples_t2i = gr.Examples(
-        label="Medical image generation examples with Hugging Face logo.",
-        examples=[
-            "Generate a CT scan with the Hugging Face logo clearly visible.",
-            "Create an MRI scan showing the Hugging Face logo embedded within the tissue.",
-            "Render a medical x-ray with the Hugging Face logo subtly visible in the background.",
-            "Generate an ultrasound image with a faint Hugging Face logo on the screen",
-        ],
-        inputs=prompt_input,
-    )
+with gr.Blocks(title="JanusFlow Medical Image Assistant") as demo:
+    gr.Markdown(value="# Medical Image Understanding and Generation")
+
+    with gr.Tab("Multimodal Understanding"):
+        with gr.Row():
+            image_input = gr.Image(label="Medical Image Input")
+            with gr.Column():
+                question_input = gr.Textbox(label="Medical Question")
+                und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top P")
+                temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
+
+        understanding_button = gr.Button("Analyze Image")
+        understanding_output = gr.Textbox(label="Analysis Response")
+
+        examples_understanding = gr.Examples(
+            label="Examples: Image Analysis",
+            examples=[
+                [
+                    "What are the visible structures in this ultrasound?",
+                    "./ultrasound.jpeg"
+                ],
+                [
+                    "Identify abnormalities in the image.",
+                    "./cardiac_ultrasound.jpeg"
+                ],
+                [
+                    "Describe the features and histological analysis in this image.",
+                    "./histology.jpeg"
+                ],
+            ],
+            inputs=[question_input, image_input],
+        )
+
+    with gr.Tab("Text-to-Image Generation"):
+        with gr.Row():
+            cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
+            step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference Steps")
+
+        prompt_input = gr.Textbox(label="Medical Image Generation Prompt")
+        seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
+        generation_button = gr.Button("Generate Medical Image")
+        image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
+
+        examples_t2i = gr.Examples(
+            label="Examples: Image Generation",
+            examples=[
+                "Generate a coronal view of a brain MRI with a tumor.",
+                "Create an X-ray image showing a fractured femur.",
+                "Create an image of Histology of Liver Cirrhosis.",
+            ],
+            inputs=prompt_input,
+        )

     understanding_button.click(
         multimodal_understanding,
         inputs=[image_input, question_input, und_seed_input, top_p, temperature],
         outputs=understanding_output
     )

     generation_button.click(
         fn=generate_image,
         inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
         outputs=image_output
     )

-demo.launch(share=True
+demo.launch(share=True)
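
Note: the comments removed from multimodal_understanding described a DICOM preprocessing path (load with pydicom.dcmread(), apply windowing/leveling, convert to a NumPy array) that was never implemented. A minimal sketch of that idea, assuming pydicom is installed and using a hypothetical load_medical_image helper (not part of the committed app.py), could look like this:

# Hypothetical helper sketching the DICOM path described in the removed comments.
# Assumes pydicom is installed; PNG/JPEG inputs need no special handling.
import numpy as np
import pydicom
from PIL import Image

def load_medical_image(path, window_center=None, window_width=None):
    """Return an RGB uint8 array suitable for Image.fromarray / the chat processor."""
    if path.lower().endswith(".dcm"):
        ds = pydicom.dcmread(path)                    # 1. load with pydicom.dcmread()
        pixels = ds.pixel_array.astype(np.float32)
        if window_center is not None and window_width is not None:
            low = window_center - window_width / 2    # 2. window/level normalization
            high = window_center + window_width / 2
            pixels = np.clip(pixels, low, high)
        span = max(float(pixels.max() - pixels.min()), 1e-6)
        pixels = (pixels - pixels.min()) / span * 255.0
        array = pixels.astype(np.uint8)               # 3. convert to np.array (uint8)
        if array.ndim == 2:                           # grayscale -> 3-channel RGB
            array = np.stack([array] * 3, axis=-1)
        return array
    return np.array(Image.open(path).convert("RGB"))

The array returned by this helper could then be passed as the image argument of multimodal_understanding, which wraps it with Image.fromarray.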
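Note: this commit also removes the @spaces.GPU(duration=120) decorators from the GPU-bound functions. On Spaces that run on ZeroGPU hardware, CUDA work must happen inside a function decorated with @spaces.GPU; if that applies to this Space, the decorators would need to be restored, roughly as in this sketch (assumes the spaces package is available):

# Sketch only: restoring the ZeroGPU decorator removed by this commit.
import spaces
import torch

@torch.inference_mode()
@spaces.GPU(duration=120)  # request a GPU for up to 120 seconds per call
def multimodal_understanding(image, question, seed, top_p, temperature):
    ...  # body unchanged from app.py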