multimodalart (HF Staff) committed · verified
Commit c8ad832 · Parent(s): c138d2a

Create app.py

Files changed (1): app.py +214 -0

app.py ADDED
@@ -0,0 +1,214 @@
+ import os
+ import sys
+ import random
+ import spaces
+ import numpy as np
+ import torch
+ from PIL import Image
+ import gradio as gr
+ from diffusers import DiffusionPipeline
+ from blip3o.conversation import conv_templates
+ from blip3o.model.builder import load_pretrained_model
+ from blip3o.utils import disable_torch_init
+ from blip3o.mm_utils import get_model_name_from_path
+ from qwen_vl_utils import process_vision_info
+ from huggingface_hub import snapshot_download
+ from transformers import AutoProcessor
+
+ # Qwen2.5-VL processor: supplies the chat template and image preprocessing
+ # used by the understanding path (process_image below).
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+ # Constants
+ MAX_SEED = 10000
+
+ # Fetch the BLIP3o checkpoint once; the diffusion decoder lives in a
+ # subfolder of the snapshot.
+ HUB_MODEL_ID = "BLIP3o/BLIP3o-Model"
+ model_snapshot_path = snapshot_download(repo_id=HUB_MODEL_ID)
+ diffusion_path = os.path.join(model_snapshot_path, "diffusion-decoder")
+
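+ # Seed every RNG that can affect sampling (Python, NumPy, Torch CPU and CUDA)
+ # so repeated runs with the same seed are reproducible.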
+ def set_global_seed(seed: int = 42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
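+ # Wrap the raw caption in the BLIP3o 'qwen' conversation template so the
+ # multimodal encoder sees the chat format it expects.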
+ def add_template(prompt_list: list[str]) -> str:
+     conv = conv_templates['qwen'].copy()
+     conv.append_message(conv.roles[0], prompt_list[0])
+     conv.append_message(conv.roles[1], None)
+     return conv.get_prompt()
+
+ def make_prompt(text: str) -> list[str]:
+     raw = f"Please generate image based on the following caption: {text}"
+     return [add_template([raw])]
+
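+ # Draw a fresh seed when "Randomize seed" is checked; otherwise keep the slider value.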
+ def randomize_seed_fn(seed: int, randomize: bool) -> int:
+     return random.randint(0, MAX_SEED) if randomize else seed
+
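+ # Text-to-image path: condition the diffusion decoder on the templated prompt
+ # and sample four candidate images for the gallery.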
+ @spaces.GPU  # the diffusion pipeline runs on CUDA, so this handler also needs a GPU slot
+ def generate_image(prompt: str, seed: int, guidance_scale: float, randomize: bool) -> list[Image.Image]:
+     seed = randomize_seed_fn(seed, randomize)
+     set_global_seed(seed)
+     formatted = make_prompt(prompt)
+     images = []
+     for _ in range(4):
+         out = pipe(formatted, guidance_scale=guidance_scale)
+         images.append(out.image)  # the custom pipeline returns a single image per call
+     return images
+
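+ # Image-understanding path: format image + question with the Qwen2.5-VL chat
+ # template and decode a text answer from the multimodal model.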
+ @spaces.GPU
+ def process_image(prompt: str, img: Image.Image) -> str:
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "image", "image": img},
+             {"type": "text", "text": prompt},
+         ],
+     }]
+     text_prompt_for_qwen = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text_prompt_for_qwen],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to('cuda:0')
+     generated_ids = multi_model.generate(**inputs, max_new_tokens=1024)
+     # Decode only the newly generated tokens, not the echoed prompt.
+     input_token_len = inputs.input_ids.shape[1]
+     generated_ids_trimmed = generated_ids[:, input_token_len:]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True,
+         clean_up_tokenization_spaces=False
+     )[0]
+     return output_text
+
+ # Initialize model + pipeline
+ disable_torch_init()
+ model_path = os.path.expanduser(sys.argv[1])  # checkpoint path is passed as the first CLI argument
+ tokenizer, multi_model, _ = load_pretrained_model(
+     model_path, None, get_model_name_from_path(model_path)
+ )
+ # The diffusion pipeline reuses the BLIP3o backbone as its multimodal encoder.
+ pipe = DiffusionPipeline.from_pretrained(
+     diffusion_path,
+     custom_pipeline="pipeline_llava_gen",
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="bf16",
+     multimodal_encoder=multi_model,
+     tokenizer=tokenizer,
+     safety_checker=None,
+ )
+ pipe.vae.to('cuda')
+ pipe.unet.to('cuda')
+
+ # Gradio UI
+ with gr.Blocks(title="BLIP3-o") as demo:
+     with gr.Row():
+         with gr.Column(scale=2):
+             image_input = gr.Image(label="Input Image (optional)", type="pil")
+             prompt_input = gr.Textbox(
+                 label="Prompt",
+                 placeholder="Describe the image you want...",
+                 lines=1
+             )
+             seed_slider = gr.Slider(
+                 label="Seed",
+                 minimum=0, maximum=int(MAX_SEED),
+                 step=1, value=42
+             )
+             randomize_checkbox = gr.Checkbox(
+                 label="Randomize seed", value=False
+             )
+             guidance_slider = gr.Slider(
+                 label="Guidance Scale",
+                 minimum=1.0, maximum=30.0,
+                 step=0.5, value=3.0
+             )
+             run_btn = gr.Button("Run")
+             clean_btn = gr.Button("Clean All")
+
+             text_only = [
+                 [None, "A cute cat."],
+                 [None, "A young woman with freckles wearing a straw hat, standing in a golden wheat field."],
+                 [None, "A group of friends having a picnic in the park."]
+             ]
+
+             image_plus_text = [
+                 ["animal-compare.png", "Are these two pictures showing the same kind of animal?"],
+                 ["funny_image.jpeg", "Why is this image funny?"],
+             ]
+
+             all_examples = text_only + image_plus_text
+
+             gr.Examples(
+                 examples=all_examples,
+                 inputs=[image_input, prompt_input],
+                 cache_examples=False,
+                 label="Try a sample: image generation (text only) or image understanding (image + text)"
+             )
+
+         with gr.Column(scale=3):
+             output_gallery = gr.Gallery(label="Generated Images", columns=4)
+             output_text = gr.Textbox(label="Generated Text", visible=False)
+
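+     # Route by input: an uploaded image triggers understanding (text output);
+     # a prompt alone triggers generation (image gallery). Only one output is shown.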
+     def run_all(img, prompt, seed, guidance, randomize):
+         if img is not None:
+             txt = process_image(prompt, img)
+             return (
+                 gr.update(value=[], visible=False),
+                 gr.update(value=txt, visible=True)
+             )
+         else:
+             imgs = generate_image(prompt, seed, guidance, randomize)
+             return (
+                 gr.update(value=imgs, visible=True),
+                 gr.update(value="", visible=False)
+             )
+
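+     # Restore every control and output to its default state.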
+     def clean_all():
+         return (
+             gr.update(value=None),
+             gr.update(value=""),
+             gr.update(value=42),
+             gr.update(value=False),
+             gr.update(value=3.0),
+             gr.update(value=[], visible=False),
+             gr.update(value="", visible=False)
+         )
+
+     # Chain seed randomization -> run_all when clicking "Run"
+     run_btn.click(
+         fn=randomize_seed_fn,
+         inputs=[seed_slider, randomize_checkbox],
+         outputs=seed_slider
+     ).then(
+         fn=run_all,
+         inputs=[image_input, prompt_input, seed_slider, guidance_slider, randomize_checkbox],
+         outputs=[output_gallery, output_text]
+     )
+
+     # Bind Enter in the prompt textbox to the same chain
+     prompt_input.submit(
+         fn=randomize_seed_fn,
+         inputs=[seed_slider, randomize_checkbox],
+         outputs=seed_slider
+     ).then(
+         fn=run_all,
+         inputs=[image_input, prompt_input, seed_slider, guidance_slider, randomize_checkbox],
+         outputs=[output_gallery, output_text]
+     )
+
+     # Clean all inputs/outputs
+     clean_btn.click(
+         fn=clean_all,
+         inputs=[],
+         outputs=[image_input, prompt_input, seed_slider,
+                  randomize_checkbox, guidance_slider,
+                  output_gallery, output_text]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
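+
+ # Note: when run outside a Space, pass the checkpoint path as the first CLI
+ # argument, e.g. `python app.py /path/to/blip3o-checkpoint` (path is illustrative).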