#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import math
from PIL import Image
from typing import List, Optional
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL, BitsAndBytesConfig
from unipicv2.pipeline_stable_diffusion_3_kontext import StableDiffusion3KontextPipeline
from unipicv2.transformer_sd3_kontext import SD3Transformer2DKontextModel
from unipicv2.stable_diffusion_3_conditioner import StableDiffusion3Conditioner
import spaces


class UniPicV2Inferencer:
    def __init__(
        self,
        model_path: str,
        qwen_vl_path: str,
        quant: str = "fp16",  # one of {"int4", "fp16"}
        image_size: int = 512,
        default_negative_prompt: str = "blurry, low quality, low resolution, distorted, deformed, broken content, missing parts, damaged details, artifacts, glitch, noise, pixelated, grainy, compression artifacts, bad composition, wrong proportion, incomplete editing, unfinished, unedited areas."
    ):
        self.model_path = model_path
        self.qwen_vl_path = qwen_vl_path
        self.quant = quant
        self.image_size = image_size
        self.default_negative_prompt = default_negative_prompt
        self.device = torch.device("cuda")
        self.pipeline = None  # initialized lazily on first use

    def _init_pipeline(self) -> StableDiffusion3KontextPipeline:
        print("Initializing UniPicV2 pipeline...")

        # ===== 1. Initialize BitsAndBytes config (4-bit NF4 quantization) =====
        bnb4 = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # ===== 2. Load SD3 transformer (quantized or bf16) =====
        if self.quant == "int4":
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path,
                subfolder="transformer",
                quantization_config=bnb4,
                device_map="auto",
                low_cpu_mem_usage=True
            )
        else:
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path,
                subfolder="transformer",
                torch_dtype=torch.bfloat16,
                device_map="auto",
                low_cpu_mem_usage=True
            )

        # ===== 3. Load VAE =====
        vae = AutoencoderKL.from_pretrained(
            self.model_path,
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True
        ).to(self.device)

        # ===== 4. Load Qwen2.5-VL (LMM); prefer FlashAttention 2, fall back to SDPA =====
        try:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                device_map="auto",
            ).to(self.device)
        except Exception:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.bfloat16,
                attn_implementation="sdpa",
                device_map="auto",
            ).to(self.device)

        # ===== 5. Load processor and strip the default system prompt from the chat template =====
        self.processor = Qwen2_5_VLProcessor.from_pretrained(self.qwen_vl_path, use_fast=False)
        if hasattr(self.processor, "chat_template") and self.processor.chat_template:
            self.processor.chat_template = self.processor.chat_template.replace(
                "{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}",
                ""
            )

        # ===== 6. Load conditioner =====
        self.conditioner = StableDiffusion3Conditioner.from_pretrained(
            self.model_path,
            subfolder="conditioner",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True
        ).to(self.device)

        # ===== 7. Load scheduler =====
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.model_path,
            subfolder="scheduler"
        )

        # ===== 8. Create pipeline =====
        pipeline = StableDiffusion3KontextPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=None, tokenizer=None,
            text_encoder_2=None, tokenizer_2=None,
            text_encoder_3=None, tokenizer_3=None,
            scheduler=scheduler
        )

        try:
            pipeline.enable_vae_slicing()
            pipeline.enable_vae_tiling()
            pipeline.enable_model_cpu_offload()
        except Exception:
            print("Note: Could not enable all memory-saving features")

        print("Pipeline initialization complete!")
        return pipeline

    def _prepare_text_inputs(self, prompt: str, negative_prompt: Optional[str] = None):
        # Build a positive/negative prompt pair as two single-turn conversations
        messages = [
            [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "text", "text": negative_prompt}]}]
        ]
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        inputs = self.processor(
            text=texts,
            images=None,
            padding=True,
            return_tensors="pt"
        )
        return inputs

    def _prepare_image_inputs(self, image: Image.Image, prompt: str, negative_prompt: Optional[str] = None):
        negative_prompt = negative_prompt or self.default_negative_prompt
        messages = [
            [{"role": "user", "content": [{"type": "image", "image": image},
                                          {"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "image", "image": image},
                                          {"type": "text", "text": negative_prompt}]}]
        ]
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)
        inputs = self.processor(
            text=texts,
            images=[image] * 2,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        )
        return inputs

    def _process_inputs(self, inputs: dict, num_queries: int):
        # Ensure all tensors are on the correct device
        inputs = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Pad the sequence with slots for the meta queries
        pad_ids = torch.zeros((input_ids.size(0), num_queries),
                              dtype=input_ids.dtype, device=self.device)
        pad_mask = torch.ones((attention_mask.size(0), num_queries),
                              dtype=attention_mask.dtype, device=self.device)
        input_ids = torch.cat([input_ids, pad_ids], dim=1)
        attention_mask = torch.cat([attention_mask, pad_mask], dim=1)

        # Get input embeddings on the same device as the embedding weights
        embed_device = self.lmm.get_input_embeddings().weight.device
        input_ids = input_ids.to(embed_device)
        inputs_embeds = self.lmm.get_input_embeddings()(input_ids)

        # Ensure meta queries are on the correct device, then fill the padded slots
        self.conditioner.meta_queries.data = self.conditioner.meta_queries.data.to(self.device)
        inputs_embeds[:, -num_queries:] = self.conditioner.meta_queries[None].expand(2, -1, -1)

        # Handle image embeddings if present
        if "pixel_values" in inputs:
            image_embeds = self.lmm.visual(
                inputs["pixel_values"].to(self.device),
                grid_thw=inputs["image_grid_thw"].to(self.device)
            )
            image_token_id = self.processor.tokenizer.convert_tokens_to_ids('<|image_pad|>')
            mask_img = (input_ids == image_token_id)
            inputs_embeds[mask_img] = image_embeds

        # Forward through the LMM backbone
        if hasattr(self.lmm.model, "rope_deltas"):
            self.lmm.model.rope_deltas = None

        # Move all remaining tensor inputs to the target device
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.device)

        outputs = self.lmm.model(
            inputs_embeds=inputs_embeds.to(self.device),
            attention_mask=attention_mask.to(self.device),
            image_grid_thw=inputs.get("image_grid_thw", None),
            use_cache=False
        )
        hidden_states = outputs.last_hidden_state[:, -num_queries:]
        hidden_states = hidden_states.to(self.device)

        # Project the query hidden states into SD3 prompt embeddings via the conditioner
        prompt_embeds, pooled_prompt_embeds = self.conditioner(hidden_states)
        return {
            "prompt_embeds": prompt_embeds[:1],
            "pooled_prompt_embeds": pooled_prompt_embeds[:1],
            "negative_prompt_embeds": prompt_embeds[1:],
            "negative_pooled_prompt_embeds": pooled_prompt_embeds[1:]
        }

    def _resize_image(self, image: Image.Image, size: int) -> Image.Image:
        # Resize so the longer side equals `size`, rounding both sides down to multiples of 32
        w, h = image.size
        if w >= h:
            new_w = size
            new_h = int(h * (new_w / w))
            new_h = (new_h // 32) * 32
        else:
            new_h = size
            new_w = int(w * (new_h / h))
            new_w = (new_w // 32) * 32
        return image.resize((new_w, new_h))

    @spaces.GPU(duration=120)
    def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> List[Image.Image]:
        if not self.pipeline:
            self.pipeline = self._init_pipeline()

        height = height or self.image_size
        width = width or self.image_size

        prompt = "Generate an image: " + prompt
        negative_prompt = "Generate an image: " + negative_prompt if negative_prompt else ""

        inputs = self._prepare_text_inputs(prompt, negative_prompt)
        num_queries = self.conditioner.config.num_queries
        embeds = self._process_inputs(inputs, num_queries)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        image = self.pipeline(
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images
        return image

    @spaces.GPU(duration=120)
    def edit_image(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> List[Image.Image]:
        # Drop alpha channels before feeding the image to the VAE
        if image.mode in ["RGBA", "LA"] or image.mode.startswith("A"):
            image = image.convert("RGB")

        if not self.pipeline:
            self.pipeline = self._init_pipeline()

        original_size = image.size
        image = self._resize_image(image, self.image_size)
        if height is None or width is None:
            height, width = image.height, image.width

        inputs = self._prepare_image_inputs(image, prompt, negative_prompt)
        num_queries = self.conditioner.config.num_queries
        embeds = self._process_inputs(inputs, num_queries)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        edited_image = self.pipeline(
            image=image,
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images
        return edited_image

    @spaces.GPU(duration=120)
    def understand_image(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512
    ) -> str:
        """
        Understand the content of an image and answer questions about it.

        Args:
            image: Input image to understand
            prompt: Question or instruction about the image
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            str: The model's response to the prompt
        """
        # Lazily initialize the pipeline (this also loads the LMM and processor)
        if not self.pipeline:
            self.pipeline = self._init_pipeline()

        # Prepare messages in the Qwen-VL chat format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            },
        ]

        # Apply chat template
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Calculate appropriate image size for processing
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)

        # Process inputs
        inputs = self.processor(
            text=[text],
            images=[image],
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        # Generate response
        generated_ids = self.lmm.generate(
            **inputs,
            max_new_tokens=max_new_tokens
        )

        # Trim the input tokens from the output
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the response
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return output_text
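

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of how the class above might be driven
# directly. The checkpoint paths below are placeholders (not shipped with this
# module) and must point at actual UniPic-V2 and Qwen2.5-VL checkpoints; a
# CUDA GPU is required, since the class hard-codes torch.device("cuda").
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    inferencer = UniPicV2Inferencer(
        model_path="path/to/unipic-v2",      # placeholder path (assumption)
        qwen_vl_path="path/to/qwen2.5-vl",   # placeholder path (assumption)
        quant="fp16",                        # or "int4" to load the transformer with NF4 quantization
        image_size=512,
    )

    # Text-to-image: the pipeline returns a list of PIL images
    images = inferencer.generate_image(
        prompt="a watercolor painting of a lighthouse at dusk",
        num_inference_steps=36,
        guidance_scale=3.0,
        seed=42,
    )
    images[0].save("generated.png")

    # Image editing: the input is resized internally so both sides are multiples of 32
    source = Image.open("generated.png")
    edited = inferencer.edit_image(
        image=source,
        prompt="make it a night scene with a starry sky",
    )
    edited[0].save("edited.png")

    # Image understanding: Qwen2.5-VL answers a question about the image
    answer = inferencer.understand_image(
        image=source,
        prompt="Describe this image in one sentence.",
    )
    print(answer)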