#!/usr/bin/env python3
# -*- coding: utf-8 -*-
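"""Inference wrapper for UniPic-V2: a StableDiffusion3 Kontext pipeline conditioned on
Qwen2.5-VL meta-query embeddings. Provides text-to-image generation, image editing,
and image understanding."""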
import torch
from PIL import Image
from typing import Optional
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL, BitsAndBytesConfig
from unipicv2.pipeline_stable_diffusion_3_kontext import StableDiffusion3KontextPipeline
from unipicv2.transformer_sd3_kontext import SD3Transformer2DKontextModel
from unipicv2.stable_diffusion_3_conditioner import StableDiffusion3Conditioner
import spaces  # Hugging Face Spaces ZeroGPU decorator (@spaces.GPU)
class UniPicV2Inferencer:
def __init__(
self,
model_path: str,
qwen_vl_path: str,
quant: str = "fp16", # {"int4", "fp16"}
image_size: int = 512,
default_negative_prompt: str = "blurry, low quality, low resolution, distorted, deformed, broken content, missing parts, damaged details, artifacts, glitch, noise, pixelated, grainy, compression artifacts, bad composition, wrong proportion, incomplete editing, unfinished, unedited areas."
):
self.model_path = model_path
self.qwen_vl_path = qwen_vl_path
self.quant = quant
self.image_size = image_size
self.default_negative_prompt = default_negative_prompt
self.device = torch.device("cuda")
        self.pipeline = None  # built lazily on the first GPU-decorated call
def _init_pipeline(self) -> StableDiffusion3KontextPipeline:
print("Initializing UniPicV2 pipeline...")
        # ===== 1. BitsAndBytes 4-bit config (used only when quant == "int4") =====
bnb4 = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
# ===== 2. Load SD3 Transformer =====
if self.quant == "int4":
transformer = SD3Transformer2DKontextModel.from_pretrained(
self.model_path, subfolder="transformer",
quantization_config=bnb4, device_map="auto", low_cpu_mem_usage=True
)
else:
transformer = SD3Transformer2DKontextModel.from_pretrained(
self.model_path, subfolder="transformer",
torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
)
# ===== 3. Load VAE =====
vae = AutoencoderKL.from_pretrained(
self.model_path, subfolder="vae",
torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
).to(self.device)
# ===== 4. Load Qwen2.5-VL (LMM) =====
try:
self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.qwen_vl_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto",
).to(self.device)
            print("Loaded Qwen2.5-VL with flash_attention_2")
        except Exception:
            print("flash_attention_2 unavailable; falling back to sdpa attention")
self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.qwen_vl_path,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
device_map="auto",
).to(self.device)
# ===== 5. Load Processor =====
self.processor = Qwen2_5_VLProcessor.from_pretrained(self.qwen_vl_path, use_fast=False)
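        # Remove the default "You are a helpful assistant." system turn that the Qwen
        # chat template would otherwise prepend to every conditioning prompt.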
if hasattr(self.processor, "chat_template") and self.processor.chat_template:
self.processor.chat_template = self.processor.chat_template.replace(
"{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}",
""
)
# ===== 6. Load Conditioner =====
self.conditioner = StableDiffusion3Conditioner.from_pretrained(
self.model_path, subfolder="conditioner",
torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
).to(self.device)
# ===== 7. Load Scheduler =====
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
self.model_path, subfolder="scheduler"
)
# ===== 8. Create Pipeline =====
pipeline = StableDiffusion3KontextPipeline(
transformer=transformer,
vae=vae,
text_encoder=None,
tokenizer=None,
text_encoder_2=None,
tokenizer_2=None,
text_encoder_3=None,
tokenizer_3=None,
scheduler=scheduler
)
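        # Best-effort memory optimizations (enable_model_cpu_offload requires accelerate)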
try:
pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()
pipeline.enable_model_cpu_offload()
except Exception:
print("Note: Could not enable all memory-saving features")
print("Pipeline initialization complete!")
return pipeline
    def _prepare_text_inputs(self, prompt: str, negative_prompt: Optional[str] = None):
        negative_prompt = negative_prompt or ""
messages = [
[{"role": "user", "content": [{"type": "text", "text": prompt}]}],
[{"role": "user", "content": [{"type": "text", "text": negative_prompt}]}]
]
texts = [
self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
for m in messages
]
inputs = self.processor(
text=texts,
images=None,
padding=True,
return_tensors="pt"
)
return inputs
    def _prepare_image_inputs(self, image: Image.Image, prompt: str, negative_prompt: Optional[str] = None):
negative_prompt = negative_prompt or self.default_negative_prompt
messages = [
[{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}],
[{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": negative_prompt}]}]
]
texts = [
self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
for m in messages
]
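        # Setting min_pixels == max_pixels pins the processor's dynamic resizing;
        # 28/32 scaling of the 32-aligned image keeps dimensions on Qwen's 28 px patch grid.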
min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)
inputs = self.processor(
text=texts,
images=[image] * 2,
min_pixels=min_pixels,
max_pixels=max_pixels,
padding=True,
return_tensors="pt"
)
return inputs
def _process_inputs(self, inputs: dict, num_queries: int):
# Ensure all tensors are on the correct device
inputs = {
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in inputs.items()
}
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
        # Append num_queries placeholder tokens; their embeddings are overwritten by the learnable meta queries below
pad_ids = torch.zeros((input_ids.size(0), num_queries),
dtype=input_ids.dtype, device=self.device)
pad_mask = torch.ones((attention_mask.size(0), num_queries),
dtype=attention_mask.dtype, device=self.device)
input_ids = torch.cat([input_ids, pad_ids], dim=1)
attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
        # Look up token embeddings on whichever device holds the embedding weights
        embed_device = self.lmm.get_input_embeddings().weight.device
        input_ids = input_ids.to(embed_device)
inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
# Ensure meta queries are on correct device
self.conditioner.meta_queries.data = self.conditioner.meta_queries.data.to(self.device)
        inputs_embeds[:, -num_queries:] = self.conditioner.meta_queries[None].expand(2, -1, -1)  # batch = [positive, negative]
# Handle image embeddings if present
if "pixel_values" in inputs:
image_embeds = self.lmm.visual(
inputs["pixel_values"].to(self.device),
grid_thw=inputs["image_grid_thw"].to(self.device)
)
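            # Scatter the visual embeddings into the <|image_pad|> token positions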
image_token_id = self.processor.tokenizer.convert_tokens_to_ids('<|image_pad|>')
mask_img = (input_ids == image_token_id)
inputs_embeds[mask_img] = image_embeds
# Forward through LMM
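        # Reset cached RoPE position offsets (if present) so state from a previous call isn't reused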
if hasattr(self.lmm.model, "rope_deltas"):
self.lmm.model.rope_deltas = None
        # Move any remaining tensor inputs (e.g. image_grid_thw) onto the target device
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.device)
outputs = self.lmm.model(
inputs_embeds=inputs_embeds.to(self.device),
attention_mask=attention_mask.to(self.device),
image_grid_thw=inputs.get("image_grid_thw", None),
use_cache=False
)
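        # Only the trailing meta-query positions carry the conditioning signal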
hidden_states = outputs.last_hidden_state[:, -num_queries:]
hidden_states = hidden_states.to(self.device)
# Get prompt embeds
prompt_embeds, pooled_prompt_embeds = self.conditioner(hidden_states)
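        # Batch row 0 holds the positive prompt, row 1 the negative prompt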
return {
"prompt_embeds": prompt_embeds[:1],
"pooled_prompt_embeds": pooled_prompt_embeds[:1],
"negative_prompt_embeds": prompt_embeds[1:],
"negative_pooled_prompt_embeds": pooled_prompt_embeds[1:]
}
def _resize_image(self, image: Image.Image, size: int) -> Image.Image:
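        # Scale so the longer side equals `size` (assumed to be a multiple of 32),
        # then snap the shorter side down to a multiple of 32 for the diffusion pipeline.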
w, h = image.size
if w >= h:
new_w = size
new_h = int(h * (new_w / w))
new_h = (new_h // 32) * 32
else:
new_h = size
new_w = int(w * (new_h / h))
new_w = (new_w // 32) * 32
        return image.resize((new_w, new_h), Image.LANCZOS)
@spaces.GPU(duration=120)
def generate_image(
self,
prompt: str,
negative_prompt: Optional[str] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
seed: int = 42
) -> Image.Image:
if not self.pipeline:
self.pipeline = self._init_pipeline()
height = height or self.image_size
width = width or self.image_size
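        # Prepend the text-to-image task prefix to both prompts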
prompt = "Generate an image: " + prompt
        negative_prompt = ("Generate an image: " + negative_prompt) if negative_prompt else ""
inputs = self._prepare_text_inputs(prompt, negative_prompt)
num_queries = self.conditioner.config.num_queries
embeds = self._process_inputs(inputs, num_queries)
generator = torch.Generator(device=self.device).manual_seed(seed)
image = self.pipeline(
prompt_embeds=embeds["prompt_embeds"].to(self.device),
pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator
        ).images[0]
        return image
@spaces.GPU(duration=120)
def edit_image(
self,
image: Image.Image,
prompt: str,
negative_prompt: Optional[str] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
seed: int = 42
) -> Image.Image:
        if image.mode != "RGB":
image = image.convert("RGB")
if not self.pipeline:
self.pipeline = self._init_pipeline()
image = self._resize_image(image, self.image_size)
if height is None or width is None:
height, width = image.height, image.width
inputs = self._prepare_image_inputs(image, prompt, negative_prompt)
num_queries = self.conditioner.config.num_queries
embeds = self._process_inputs(inputs, num_queries)
generator = torch.Generator(device=self.device).manual_seed(seed)
edited_image = self.pipeline(
image=image,
prompt_embeds=embeds["prompt_embeds"].to(self.device),
pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator
        ).images[0]
        return edited_image
@spaces.GPU(duration=120)
def understand_image(
self,
image: Image.Image,
prompt: str,
max_new_tokens: int = 512
) -> str:
"""
Understand the content of an image and answer questions about it.
Args:
image: Input image to understand
prompt: Question or instruction about the image
max_new_tokens: Maximum number of tokens to generate
Returns:
str: The model's response to the prompt
"""
        # Lazily build the pipeline; this also loads the LMM and processor used here
        if not self.pipeline:
            self.pipeline = self._init_pipeline()
        # Prepare messages in Qwen-VL format
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
},
]
# Apply chat template
text = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
        # Pin the vision resize to a fixed pixel budget (28/32 of the input size per dimension)
min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)
# Process inputs
inputs = self.processor(
text=[text],
images=[image],
min_pixels=min_pixels,
max_pixels=max_pixels,
padding=True,
return_tensors="pt"
).to(self.device)
# Generate response
generated_ids = self.lmm.generate(
**inputs,
max_new_tokens=max_new_tokens
)
# Trim input tokens from output
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Decode the response
output_text = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
return output_text
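if __name__ == "__main__":
    # Minimal usage sketch, not part of the Space app. The checkpoint paths below are
    # placeholders; point them at the actual UniPic-V2 and Qwen2.5-VL weights.
    inferencer = UniPicV2Inferencer(
        model_path="path/to/UniPic2-Metaquery",          # placeholder path
        qwen_vl_path="path/to/Qwen2.5-VL-Instruct",      # placeholder path
        quant="int4",
        image_size=512,
    )
    img = inferencer.generate_image("a watercolor painting of a lighthouse at dusk", seed=0)
    img.save("generated.png")
    edited = inferencer.edit_image(img, "add a full moon in the sky")
    edited.save("edited.png")
    print(inferencer.understand_image(edited, "Describe this image."))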