#!/usr/bin/env python3
# -*- coding: utf-8 -*-
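"""Inference helper for UniPic-2 (SD3 Kontext transformer conditioned via Qwen2.5-VL).

Wraps three tasks behind a single UniPicV2Inferencer class:
  * generate_image    -- text-to-image generation
  * edit_image        -- instruction-based image editing
  * understand_image  -- visual question answering about an image
"""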
import torch
import math
from PIL import Image
from typing import List, Optional
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL, BitsAndBytesConfig
from unipicv2.pipeline_stable_diffusion_3_kontext import StableDiffusion3KontextPipeline
from unipicv2.transformer_sd3_kontext import SD3Transformer2DKontextModel
from unipicv2.stable_diffusion_3_conditioner import StableDiffusion3Conditioner
import spaces
class UniPicV2Inferencer:
    def __init__(
        self,
        model_path: str,
        qwen_vl_path: str,
        quant: str = "fp16",  # {"int4", "fp16"}
        image_size: int = 512,
        default_negative_prompt: str = "blurry, low quality, low resolution, distorted, deformed, broken content, missing parts, damaged details, artifacts, glitch, noise, pixelated, grainy, compression artifacts, bad composition, wrong proportion, incomplete editing, unfinished, unedited areas."
    ):
        self.model_path = model_path
        self.qwen_vl_path = qwen_vl_path
        self.quant = quant
        self.image_size = image_size
        self.default_negative_prompt = default_negative_prompt
        self.device = torch.device("cuda")
        self.pipeline = None  # lazily created on first call to generate_image / edit_image / understand_image
    def _init_pipeline(self) -> StableDiffusion3KontextPipeline:
        print("Initializing UniPicV2 pipeline...")
        # ===== 1. Initialize BNB Config =====
        bnb4 = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        # ===== 2. Load SD3 Transformer =====
        if self.quant == "int4":
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path, subfolder="transformer",
                quantization_config=bnb4, device_map="auto", low_cpu_mem_usage=True
            )
        else:
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path, subfolder="transformer",
                torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
            )
        # ===== 3. Load VAE =====
        vae = AutoencoderKL.from_pretrained(
            self.model_path, subfolder="vae",
            torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
        ).to(self.device)
        # ===== 4. Load Qwen2.5-VL (LMM) =====
        try:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                device_map="auto",
            ).to(self.device)
print("**"*20) | |
        except Exception:
            # Fall back to SDPA attention if FlashAttention 2 is unavailable
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.bfloat16,
                attn_implementation="sdpa",
                device_map="auto",
            ).to(self.device)
        # ===== 5. Load Processor =====
        self.processor = Qwen2_5_VLProcessor.from_pretrained(self.qwen_vl_path, use_fast=False)
        # Strip the auto-inserted default system prompt from the chat template
        if hasattr(self.processor, "chat_template") and self.processor.chat_template:
            self.processor.chat_template = self.processor.chat_template.replace(
                "{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}",
                ""
            )
        # ===== 6. Load Conditioner =====
        self.conditioner = StableDiffusion3Conditioner.from_pretrained(
            self.model_path, subfolder="conditioner",
            torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
        ).to(self.device)
        # ===== 7. Load Scheduler =====
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.model_path, subfolder="scheduler"
        )
        # ===== 8. Create Pipeline =====
        pipeline = StableDiffusion3KontextPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=None,
            tokenizer=None,
            text_encoder_2=None,
            tokenizer_2=None,
            text_encoder_3=None,
            tokenizer_3=None,
            scheduler=scheduler
        )
        try:
            pipeline.enable_vae_slicing()
            pipeline.enable_vae_tiling()
            pipeline.enable_model_cpu_offload()
        except Exception:
            print("Note: Could not enable all memory-saving features")
        print("Pipeline initialization complete!")
        return pipeline
    def _prepare_text_inputs(self, prompt: str, negative_prompt: Optional[str] = None):
        negative_prompt = negative_prompt if negative_prompt is not None else ""
        messages = [
            [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "text", "text": negative_prompt}]}]
        ]
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        inputs = self.processor(
            text=texts,
            images=None,
            padding=True,
            return_tensors="pt"
        )
        return inputs
    def _prepare_image_inputs(self, image: Image.Image, prompt: str, negative_prompt: Optional[str] = None):
        negative_prompt = negative_prompt or self.default_negative_prompt
        messages = [
            [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": negative_prompt}]}]
        ]
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        # Fix the processor's pixel budget from the 32-aligned image size (28/32 scaling per side, matching Qwen's 28-px patches)
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)
        inputs = self.processor(
            text=texts,
            images=[image] * 2,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        )
        return inputs
    def _process_inputs(self, inputs: dict, num_queries: int):
        # Ensure all tensors are on the correct device
        inputs = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        # Pad with meta queries
        pad_ids = torch.zeros((input_ids.size(0), num_queries),
                              dtype=input_ids.dtype, device=self.device)
        pad_mask = torch.ones((attention_mask.size(0), num_queries),
                              dtype=attention_mask.dtype, device=self.device)
        input_ids = torch.cat([input_ids, pad_ids], dim=1)
        attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
        # Get input embeddings
        # Get the device that holds the embedding weights
        embed_device = self.lmm.get_input_embeddings().weight.device
        # Make sure input_ids lives on the same device
        input_ids = input_ids.to(embed_device)
        inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
        # Ensure meta queries are on the correct device
        self.conditioner.meta_queries.data = self.conditioner.meta_queries.data.to(self.device)
        inputs_embeds[:, -num_queries:] = self.conditioner.meta_queries[None].expand(2, -1, -1)
        # Handle image embeddings if present
        if "pixel_values" in inputs:
            image_embeds = self.lmm.visual(
                inputs["pixel_values"].to(self.device),
                grid_thw=inputs["image_grid_thw"].to(self.device)
            )
            image_token_id = self.processor.tokenizer.convert_tokens_to_ids('<|image_pad|>')
            mask_img = (input_ids == image_token_id)
            inputs_embeds[mask_img] = image_embeds
        # Forward through LMM
        if hasattr(self.lmm.model, "rope_deltas"):
            self.lmm.model.rope_deltas = None
        # Move every tensor input onto the target device
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.device)
        outputs = self.lmm.model(
            inputs_embeds=inputs_embeds.to(self.device),
            attention_mask=attention_mask.to(self.device),
            image_grid_thw=inputs.get("image_grid_thw", None),
            use_cache=False
        )
        hidden_states = outputs.last_hidden_state[:, -num_queries:]
        hidden_states = hidden_states.to(self.device)
        # Get prompt embeds
        prompt_embeds, pooled_prompt_embeds = self.conditioner(hidden_states)
        return {
            "prompt_embeds": prompt_embeds[:1],
            "pooled_prompt_embeds": pooled_prompt_embeds[:1],
            "negative_prompt_embeds": prompt_embeds[1:],
            "negative_pooled_prompt_embeds": pooled_prompt_embeds[1:]
        }
    def _resize_image(self, image: Image.Image, size: int) -> Image.Image:
        # Scale the longer side to `size` and snap the shorter side down to a multiple of 32
        w, h = image.size
        if w >= h:
            new_w = size
            new_h = int(h * (new_w / w))
            new_h = (new_h // 32) * 32
        else:
            new_h = size
            new_w = int(w * (new_h / h))
            new_w = (new_w // 32) * 32
        return image.resize((new_w, new_h))
    def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> Image.Image:
        if not self.pipeline:
            self.pipeline = self._init_pipeline()
        height = height or self.image_size
        width = width or self.image_size
        prompt = "Generate an image: " + prompt
negative_prompt = "Generate an image: " + negative_prompt if negative_prompt else "" #self.default_negative_prompt | |
        inputs = self._prepare_text_inputs(prompt, negative_prompt)
        num_queries = self.conditioner.config.num_queries
        embeds = self._process_inputs(inputs, num_queries)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        image = self.pipeline(
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images[0]
        return image
    def edit_image(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> Image.Image:
        if image.mode in ["RGBA", "LA"] or image.mode.startswith("A"):
            image = image.convert("RGB")
        if not self.pipeline:
            self.pipeline = self._init_pipeline()
        original_size = image.size
        image = self._resize_image(image, self.image_size)
        if height is None or width is None:
            height, width = image.height, image.width
        inputs = self._prepare_image_inputs(image, prompt, negative_prompt)
        num_queries = self.conditioner.config.num_queries
        embeds = self._process_inputs(inputs, num_queries)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        edited_image = self.pipeline(
            image=image,
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images[0]
        return edited_image
    def understand_image(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512
    ) -> str:
        """
        Understand the content of an image and answer questions about it.

        Args:
            image: Input image to understand
            prompt: Question or instruction about the image
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            str: The model's response to the prompt
        """
        # The LMM and processor are loaded in _init_pipeline, so make sure it has run
        if not self.pipeline:
            self.pipeline = self._init_pipeline()
        # Prepare messages in Qwen-VL format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        # Apply chat template
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Calculate appropriate image size for processing
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)
        # Process inputs
        inputs = self.processor(
            text=[text],
            images=[image],
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        # Generate response
        generated_ids = self.lmm.generate(
            **inputs,
            max_new_tokens=max_new_tokens
        )
        # Trim input tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        # Decode the response
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return output_text
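
if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint paths below are placeholders, not the
    # paths this Space actually uses -- point them at a local UniPic-2 checkpoint
    # (transformer/vae/conditioner/scheduler subfolders) and a Qwen2.5-VL model.
    inferencer = UniPicV2Inferencer(
        model_path="path/to/unipic-v2",      # hypothetical path
        qwen_vl_path="path/to/Qwen2.5-VL",   # hypothetical path
        quant="int4",                        # "int4" (NF4 transformer) or "fp16"
        image_size=512,
    )

    # Text-to-image generation
    img = inferencer.generate_image("a red bicycle leaning against a brick wall", seed=0)
    img.save("generated.png")

    # Instruction-based editing of an existing image
    edited = inferencer.edit_image(Image.open("generated.png"), "make the bicycle blue")
    edited.save("edited.png")

    # Visual question answering
    print(inferencer.understand_image(Image.open("edited.png"), "What color is the bicycle?"))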