from typing import List

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from safetensors.torch import load_file

from nested_attention_processor import AttnProcessor, NestedAttnProcessor
from utils import get_generator
from resampler import Resampler


def add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token):
    """Add the placeholder token to both SDXL tokenizers and initialize its
    embedding from the (single-token) initializer word."""
    num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
    num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
    if num_added_tokens1 != 1 or num_added_tokens2 != 1:
        raise ValueError("Failed to add placeholder token to tokenizer")

    token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
    token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
    if len(token_ids1) > 1 or len(token_ids2) > 1:
        raise ValueError("The initializer token must be a single token.")
    initializer_token_id1 = token_ids1[0]
    initializer_token_id2 = token_ids2[0]

    placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
    placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])

    # Grow the embedding matrices to cover the newly added token.
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))

    token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
    token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in placeholder_token_ids1:
            token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
        for token_id in placeholder_token_ids2:
            token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()


class NestedAdapterInference:
    def __init__(
        self,
        sd_pipe,
        image_encoder_path,
        adapter_ckpt,
        resampler_num_queries,
        vq_normalize_factor,
        device,
    ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.adapter_ckpt = adapter_ckpt
        self.vq_normalize_factor = vq_normalize_factor

        self.pipe = sd_pipe.to(self.device)
        self.set_nested_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path, use_safetensors=True
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # spatial features model
        self.qformer = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=resampler_num_queries,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)

        if adapter_ckpt is not None:
            self.load_nested_adapter()

    def set_nested_adapter(self):
        """Replace the UNet attention processors: self-attention layers keep the
        default processor, cross-attention layers get a NestedAttnProcessor."""
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = NestedAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    normalize_factor=self.vq_normalize_factor,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    def load_nested_adapter(self):
state_dict = {"adapter_modules": {}, "qformer": {}} f = load_file(self.adapter_ckpt) for key in f.keys(): if key.startswith("adapter_modules."): state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[ key ] elif key.startswith("spatial_features_model."): state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[ key ] self.qformer.load_state_dict(state_dict["qformer"]) adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) adapter_layers.load_state_dict(state_dict["adapter_modules"]) @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None): if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor( images=pil_image, return_tensors="pt" ).pixel_values clip_image_embeds = self.image_encoder( clip_image.to(self.device, dtype=torch.float16) ) spatial_clip_image_embeds = clip_image_embeds.last_hidden_state spatial_clip_image_embeds = spatial_clip_image_embeds[:, 1:] # remove CLS token return spatial_clip_image_embeds def generate( self, pil_image=None, clip_image_embeds=None, prompt=None, placeholder_token_ids=None, negative_prompt=None, scale=1.0, num_samples=4, seed=None, guidance_scale=5.0, num_inference_steps=30, multiple_images=False, special_token_weight=1.0, **kwargs, ): if pil_image is not None: num_prompts = ( 1 if isinstance(pil_image, Image.Image) or multiple_images else len(pil_image) ) else: num_prompts = clip_image_embeds.size(0) if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = ( "monochrome, lowres, bad anatomy, worst quality, low quality" ) if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts text_input_ids = self.pipe.tokenizer( prompt, max_length=self.pipe.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt", ).input_ids special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[ :, 1 ] spatial_clip_image_embeds = self.get_image_embeds( pil_image=pil_image, clip_image_embeds=clip_image_embeds ) # (bs, 256, 1280) with torch.no_grad(): ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.pipe.encode_prompt( prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[ :, 1 ] with torch.no_grad(): qformer_tokens_out = self.qformer(spatial_clip_image_embeds) if multiple_images: b, num_tokens, d = qformer_tokens_out.shape qformer_tokens_out = qformer_tokens_out.reshape( 1, num_tokens * b, d ) bs_embed, num_tokens, _ = qformer_tokens_out.shape qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1, 1) qformer_tokens_out = qformer_tokens_out.view( bs_embed * num_samples, num_tokens, -1 ) qformer_tokens_out = qformer_tokens_out.repeat_interleave(2, dim=0) cross_attention_kwargs = { "qformer_tokens_out": qformer_tokens_out, "special_token_indices": special_token_indices, "special_token_weight": special_token_weight, "inference_mode": True, } generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, 
            generator=generator,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).images

        return images
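

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how the
# pieces above are meant to fit together, assuming an SDXL pipeline from
# diffusers. The checkpoint paths, the "<person>"/"person" token pair, and the
# resampler_num_queries / vq_normalize_factor values below are illustrative
# placeholders, not values taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )

    # Register the placeholder token in both SDXL tokenizers, initializing its
    # embedding from a single-token word.
    add_special_token_to_tokenizer(
        pipe, placeholder_token="<person>", initializer_token="person"
    )
    placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])

    adapter = NestedAdapterInference(
        sd_pipe=pipe,
        image_encoder_path="path/to/clip_image_encoder",  # placeholder path
        adapter_ckpt="path/to/adapter.safetensors",  # placeholder path
        resampler_num_queries=128,  # assumed value
        vq_normalize_factor=1.0,  # assumed value
        device="cuda",
    )

    images = adapter.generate(
        pil_image=Image.open("reference.jpg").convert("RGB"),
        prompt="a photo of <person> at the beach",
        placeholder_token_ids=placeholder_token_ids,
        num_samples=4,
        seed=42,
    )
    for i, image in enumerate(images):
        image.save(f"output_{i}.png")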