yichenchenchen committed
Commit 9a00163 · verified · 1 Parent(s): 932089b

Create inferencer.py

Files changed (1)
  1. inferencer.py +389 -0
inferencer.py ADDED
@@ -0,0 +1,389 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import math
from PIL import Image
from typing import List, Optional
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL, BitsAndBytesConfig
from unipicv2.pipeline_stable_diffusion_3_kontext import StableDiffusion3KontextPipeline
from unipicv2.transformer_sd3_kontext import SD3Transformer2DKontextModel
from unipicv2.stable_diffusion_3_conditioner import StableDiffusion3Conditioner
import spaces

class UniPicV2Inferencer:
    def __init__(
        self,
        model_path: str,
        qwen_vl_path: str,
        quant: str = "int4",  # {"int4", "fp16"}
        image_size: int = 512,
        default_negative_prompt: str = "blurry, low quality"
    ):
        self.model_path = model_path
        self.qwen_vl_path = qwen_vl_path
        self.quant = quant
        self.image_size = image_size
        self.default_negative_prompt = default_negative_prompt
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipeline = None  # built lazily on first use (see _init_pipeline)

    def _init_pipeline(self) -> StableDiffusion3KontextPipeline:
        print("Initializing UniPicV2 pipeline...")

        # ===== 1. Initialize BNB Config =====
        bnb4 = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # ===== 2. Load SD3 Transformer =====
        if self.quant == "int4":
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path, subfolder="transformer",
                quantization_config=bnb4, device_map="auto", low_cpu_mem_usage=True
            )
        else:
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path, subfolder="transformer",
                torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True
            )

        # ===== 3. Load VAE =====
        vae = AutoencoderKL.from_pretrained(
            self.model_path, subfolder="vae",
            torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True
        ).to(self.device)

        # ===== 4. Load Qwen2.5-VL (LMM) =====
        try:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.float16,
                attn_implementation="flash_attention_2",
                device_map="auto",
            ).to(self.device)
            print("Loaded Qwen2.5-VL with flash_attention_2")
        except Exception:
            # Fall back to SDPA attention if FlashAttention2 is unavailable
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.float16,
                attn_implementation="sdpa",
                device_map="auto",
            ).to(self.device)

        # ===== 5. Load Processor =====
        self.processor = Qwen2_5_VLProcessor.from_pretrained(self.qwen_vl_path, use_fast=False)

        # Strip the auto-injected default system prompt from the chat template
        if hasattr(self.processor, "chat_template") and self.processor.chat_template:
            self.processor.chat_template = self.processor.chat_template.replace(
                "{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}",
                ""
            )

        # ===== 6. Load Conditioner =====
        self.conditioner = StableDiffusion3Conditioner.from_pretrained(
            self.model_path, subfolder="conditioner",
            torch_dtype=torch.float16, low_cpu_mem_usage=True
        ).to(self.device)

        # ===== 7. Load Scheduler =====
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.model_path, subfolder="scheduler"
        )

        # ===== 8. Create Pipeline =====
        # Text encoders/tokenizers are None: prompt embeddings come from the LMM + conditioner instead
        pipeline = StableDiffusion3KontextPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=None,
            tokenizer=None,
            text_encoder_2=None,
            tokenizer_2=None,
            text_encoder_3=None,
            tokenizer_3=None,
            scheduler=scheduler
        )

        # Enable memory-saving features where supported
        try:
            pipeline.enable_vae_slicing()
            pipeline.enable_vae_tiling()
            pipeline.enable_model_cpu_offload()
        except Exception:
            print("Note: Could not enable all memory-saving features")

        print("Pipeline initialization complete!")
        return pipeline

    def _prepare_text_inputs(self, prompt: str, negative_prompt: Optional[str] = None):
        negative_prompt = negative_prompt or self.default_negative_prompt

        # One chat per prompt: index 0 = positive, index 1 = negative
        messages = [
            [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "text", "text": negative_prompt}]}]
        ]

        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]

        inputs = self.processor(
            text=texts,
            images=None,
            padding=True,
            return_tensors="pt"
        )

        return inputs

    def _prepare_image_inputs(self, image: Image.Image, prompt: str, negative_prompt: Optional[str] = None):
        negative_prompt = negative_prompt or self.default_negative_prompt

        messages = [
            [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": negative_prompt}]}]
        ]

        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]

        # Pin the vision token budget so the processor neither up- nor down-samples the image
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)

        inputs = self.processor(
            text=texts,
            images=[image] * 2,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        )

        return inputs

    def _process_inputs(self, inputs: dict, num_queries: int):
        # Ensure all tensors are on the correct device
        inputs = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Pad with meta queries
        pad_ids = torch.zeros((input_ids.size(0), num_queries),
                              dtype=input_ids.dtype, device=self.device)
        pad_mask = torch.ones((attention_mask.size(0), num_queries),
                              dtype=attention_mask.dtype, device=self.device)

        input_ids = torch.cat([input_ids, pad_ids], dim=1)
        attention_mask = torch.cat([attention_mask, pad_mask], dim=1)

        # Get input embeddings
        inputs_embeds = self.lmm.get_input_embeddings()(input_ids)

        # Ensure meta queries are on the correct device, then overwrite the padded
        # positions with the learned meta-query embeddings (positive + negative batch of 2)
        self.conditioner.meta_queries.data = self.conditioner.meta_queries.data.to(self.device)
        inputs_embeds[:, -num_queries:] = self.conditioner.meta_queries[None].expand(2, -1, -1)

        # Handle image embeddings if present
        if "pixel_values" in inputs:
            image_embeds = self.lmm.visual(
                inputs["pixel_values"].to(self.device),
                grid_thw=inputs["image_grid_thw"].to(self.device)
            )
            # Scatter the vision features into the <|image_pad|> placeholder positions
            image_token_id = self.processor.tokenizer.convert_tokens_to_ids('<|image_pad|>')
            mask_img = (input_ids == image_token_id)
            inputs_embeds[mask_img] = image_embeds

        # Forward through LMM
        if hasattr(self.lmm.model, "rope_deltas"):
            self.lmm.model.rope_deltas = None

        outputs = self.lmm.model(
            inputs_embeds=inputs_embeds.to(self.device),
            attention_mask=attention_mask.to(self.device),
            image_grid_thw=inputs.get("image_grid_thw", None),
            use_cache=False
        )

        hidden_states = outputs.last_hidden_state[:, -num_queries:]
        hidden_states = hidden_states.to(self.device)

        # Get prompt embeds
        prompt_embeds, pooled_prompt_embeds = self.conditioner(hidden_states)

        # Row 0 is the positive prompt, row 1 the negative prompt
        return {
            "prompt_embeds": prompt_embeds[:1],
            "pooled_prompt_embeds": pooled_prompt_embeds[:1],
            "negative_prompt_embeds": prompt_embeds[1:],
            "negative_pooled_prompt_embeds": pooled_prompt_embeds[1:]
        }

    def _resize_image(self, image: Image.Image, size: int) -> Image.Image:
        # Resize so the longer side equals `size` and the shorter side is rounded down to a multiple of 32
        w, h = image.size
        if w >= h:
            new_w = size
            new_h = int(h * (new_w / w))
            new_h = (new_h // 32) * 32
        else:
            new_h = size
            new_w = int(w * (new_h / h))
            new_w = (new_w // 32) * 32

        return image.resize((new_w, new_h))

    @spaces.GPU(duration=120)
    def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> Image.Image:
        if not self.pipeline:
            self.pipeline = self._init_pipeline()
        height = height or self.image_size
        width = width or self.image_size

        inputs = self._prepare_text_inputs(prompt, negative_prompt)
        num_queries = self.conditioner.config.num_queries
        embeds = self._process_inputs(inputs, num_queries)

        generator = torch.Generator(device=self.device).manual_seed(seed)

        image = self.pipeline(
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images[0]

        return image

    @spaces.GPU(duration=120)
    def edit_image(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> Image.Image:
        if not self.pipeline:
            self.pipeline = self._init_pipeline()
        # Condition on a 32-aligned copy of the input, but generate at the original resolution by default
        original_size = image.size
        image = self._resize_image(image, self.image_size)
        height = height or original_size[1]
        width = width or original_size[0]

        inputs = self._prepare_image_inputs(image, prompt, negative_prompt)
        num_queries = self.conditioner.config.num_queries
        embeds = self._process_inputs(inputs, num_queries)

        generator = torch.Generator(device=self.device).manual_seed(seed)

        edited_image = self.pipeline(
            image=image,
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images[0]

        return edited_image

    @spaces.GPU(duration=120)
    def understand_image(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512
    ) -> str:
        """
        Understand the content of an image and answer questions about it.

        Args:
            image: Input image to understand
            prompt: Question or instruction about the image
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            str: The model's response to the prompt
        """
        # Lazily build the pipeline so the LMM and processor are loaded
        if not self.pipeline:
            self.pipeline = self._init_pipeline()

        # Prepare messages in Qwen-VL format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            },
        ]

        # Apply chat template
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Calculate appropriate image size for processing
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)

        # Process inputs
        inputs = self.processor(
            text=[text],
            images=[image],
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        # Generate response
        generated_ids = self.lmm.generate(
            **inputs,
            max_new_tokens=max_new_tokens
        )

        # Trim input tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the response
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        return output_text
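
For context, a minimal usage sketch of the class added in this commit (not part of the committed file). The checkpoint paths, prompts, and output filenames below are placeholders; it assumes a CUDA-capable machine with the unipicv2 package, the spaces package, and the model weights available, and that this file is importable as inferencer.

from PIL import Image
from inferencer import UniPicV2Inferencer

# Hypothetical paths: substitute the actual UniPic-2 and Qwen2.5-VL checkpoints.
inferencer = UniPicV2Inferencer(
    model_path="path/to/unipicv2_checkpoint",
    qwen_vl_path="path/to/qwen2.5-vl_checkpoint",
    quant="int4",
)

# Text-to-image generation
img = inferencer.generate_image("a watercolor fox in a snowy forest", seed=7)
img.save("t2i.png")

# Image editing conditioned on an existing image
edited = inferencer.edit_image(Image.open("t2i.png"), "make it nighttime")
edited.save("edited.png")

# Image understanding (VQA / captioning via the Qwen2.5-VL backbone)
answer = inferencer.understand_image(Image.open("t2i.png"), "Describe this image.")
print(answer)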