import os
import warnings
from typing import Optional, Tuple

import numpy as np
import torch
from diffusers import AutoencoderKL, EulerDiscreteScheduler, EDMDPMSolverMultistepScheduler
from PIL import Image
from scipy.spatial.distance import cdist
from torch.fft import fftn, fftshift, ifftn, ifftshift
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)

import unet.pipeline_stable_diffusion_xl as pipeline_stable_diffusion_xl
from unet.unet import UNet2DConditionModel
from unet.unet_controller import UNetController


def ipca(q, k, v, scale, unet_controller: Optional[UNetController] = None):
    """Cross-attention where the positive branch attends to the keys/values of all prompts.

    The batch dimension of ``q``/``k``/``v`` is split in half: the first half is
    the negative branch (plain scaled dot-product attention), the second half is
    the positive branch, whose keys/values are the concatenation of the
    keys/values of every positive sample along the token axis.

    Args:
        q: Query tensor, eg. [4, 20, 1024, 64].
        k: Key tensor, eg. [4, 20, 77, 64].
        v: Value tensor, eg. [4, 20, 77, 64].
        scale: Factor applied to the raw attention scores.
        unet_controller: When given, a dropout mask (and optionally an embeds
            mask) is applied to the positive scores before the softmax.

    Returns:
        The attention output, negative half followed by positive half.
    """
    q_neg, q_pos = torch.split(q, q.size(0) // 2, dim=0)
    k_neg, k_pos = torch.split(k, k.size(0) // 2, dim=0)
    v_neg, v_pos = torch.split(v, v.size(0) // 2, dim=0)

    # 1. negative_attn
    scores_neg = torch.matmul(q_neg, k_neg.transpose(-2, -1)) * scale
    attn_weights_neg = torch.softmax(scores_neg, dim=-1)
    attn_output_neg = torch.matmul(attn_weights_neg, v_neg)

    # 2. positive_attn (we do ipca only on positive branch)
    # 2.1 ipca
    # K+ = [K1 ⊕ K2 ⊕ ... ⊕ KN] (already transposed, ready for q @ K+)
    k_plus = torch.cat(tuple(k_pos.transpose(-2, -1)), dim=2).unsqueeze(0).repeat(k_pos.size(0), 1, 1, 1)
    # V+ = [V1 ⊕ V2 ⊕ ... ⊕ VN]
    v_plus = torch.cat(tuple(v_pos), dim=1).unsqueeze(0).repeat(v_pos.size(0), 1, 1, 1)

    # 2.2 apply mask
    scores_pos = torch.matmul(q_pos, k_plus) * scale
    if unet_controller is not None:
        # 2.2.1 apply dropout mask
        dropout_mask = gen_dropout_mask(scores_pos.shape, unet_controller, unet_controller.Ipca_dropout)  # eg: [a,1024,154]

        # 2.2.3 apply embeds mask (modifies dropout_mask in place)
        if unet_controller.Use_embeds_mask:
            apply_embeds_mask(unet_controller, dropout_mask, add_eot=False)

        mask = dropout_mask.unsqueeze(1).repeat(1, scores_pos.size(1), 1, 1)
        # log(1) = 0 keeps a score, log(0) = -inf removes it from the softmax.
        attn_weights_pos = torch.softmax(scores_pos + torch.log(mask), dim=-1)
    else:
        attn_weights_pos = torch.softmax(scores_pos, dim=-1)

    attn_output_pos = torch.matmul(attn_weights_pos, v_plus)

    # 3. combine
    attn_output = torch.cat((attn_output_neg, attn_output_pos), dim=0)

    return attn_output


def _ipca_with_stored_kv(q, k, v, k_store, v_store, scale, unet_controller):
    """Run ``ipca`` on one neg/pos pair extended with stored keys/values; return the pair's rows."""
    q_neg, q_pos = torch.split(q, q.size(0) // 2, dim=0)
    k_neg, k_pos = torch.split(k, k.size(0) // 2, dim=0)
    v_neg, v_pos = torch.split(v, v.size(0) // 2, dim=0)

    k_store_neg, k_store_pos = torch.split(k_store, k_store.size(0) // 2, dim=0)
    v_store_neg, v_store_pos = torch.split(v_store, v_store.size(0) // 2, dim=0)

    # Queries are not stored: the current queries are duplicated so that the
    # batch layout matches the extended keys/values.
    q_all = torch.cat((q_neg, q_neg, q_pos, q_pos), dim=0)
    k_all = torch.cat((k_neg, k_store_neg, k_pos, k_store_pos), dim=0)
    v_all = torch.cat((v_neg, v_store_neg, v_pos, v_store_pos), dim=0)

    attn_output = ipca(q_all, k_all, v_all, scale, unet_controller)
    # Rows 0 and 2 are the current negative and positive samples; the other
    # rows belong to the duplicated queries and are discarded.
    return attn_output[[0, 2], :, :, :]


def ipca2(q, k, v, scale, unet_controller: Optional[UNetController] = None):
    """Cross-attention that either stores K/V or re-uses previously stored K/V.

    A per-timestep call counter (``unet_controller.ipca2_index``) is maintained
    so that each call within a timestep gets its own store key.

    * ``unet_controller.Store_qkv is True``: ``k``/``v`` are saved in
      ``k_store``/``v_store`` and plain attention is returned.
    * otherwise: the stored ``k``/``v`` for the same key are concatenated with
      the current ones and ``ipca`` is applied. When
      ``frame_prompt_express_list`` is set (batch > 1) this is done per
      (negative, positive) pair of the batch.

    Args:
        q: Query tensor, eg. [4, 20, 1024, 64].
        k: Key tensor, eg. [4, 20, 77, 64].
        v: Value tensor, eg. [4, 20, 77, 64].
        scale: Factor applied to the raw attention scores.
        unet_controller: Controller holding the flags and K/V stores. Despite
            the ``None`` default it is dereferenced unconditionally, so it must
            be provided.

    Returns:
        The attention output with the same batch layout as ``q``
        (negative half followed by positive half).
    """
    if unet_controller.ipca_time_step != unet_controller.current_time_step:
        unet_controller.ipca_time_step = unet_controller.current_time_step
        unet_controller.ipca2_index = 0
    else:
        unet_controller.ipca2_index += 1

    key = f"cross {unet_controller.current_time_step} {unet_controller.current_unet_position} {unet_controller.ipca2_index}"

    if unet_controller.Store_qkv is True:
        unet_controller.k_store[key] = k
        unet_controller.v_store[key] = v

        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
    else:
        k_store = unet_controller.k_store[key]
        v_store = unet_controller.v_store[key]

        # batch > 1
        if unet_controller.frame_prompt_express_list is not None:
            batch_size = q.size(0) // 2
            attn_output_list = []

            for i in range(batch_size):
                # Sample i's negative row lives at i, its positive row at i + batch_size.
                pair = [i, i + batch_size]
                attn_output_i = _ipca_with_stored_kv(
                    q[pair, :, :, :],
                    k[pair, :, :, :],
                    v[pair, :, :, :],
                    k_store,
                    v_store,
                    scale,
                    unet_controller,
                )
                attn_output_list.append(attn_output_i)

            # attn_output_ is interleaved [neg_0, pos_0, neg_1, pos_1, ...];
            # restore the [neg_0, neg_1, ..., pos_0, pos_1, ...] layout of q.
            attn_output_ = torch.cat(attn_output_list, dim=0)
            attn_output = torch.zeros(size=(q.size(0), *attn_output_.shape[1:]), device=q.device, dtype=q.dtype)
            attn_output[:batch_size] = attn_output_[0::2]
            attn_output[batch_size:2 * batch_size] = attn_output_[1::2]
        # batch = 1
        else:
            attn_output = _ipca_with_stored_kv(q, k, v, k_store, v_store, scale, unet_controller)

    return attn_output


def apply_embeds_mask(unet_controller: Optional[UNetController], dropout_mask, add_eot=False):
    """Zero, in place, the mask columns >= 78 of batch entry 0 that are not id-prompt tokens.

    The id prompt (``unet_controller.id_prompt``) is located inside
    ``unet_controller.prompts[0]`` at token level. Token positions are shifted
    by the hard-coded offset 77 (presumably the per-prompt token length, so
    that they address the second segment of the concatenated keys — confirm
    against the caller).

    Args:
        unet_controller: Provides ``id_prompt``, ``tokenizer`` and ``prompts``.
        dropout_mask: Mask tensor indexed as ``[batch, query, key]``; modified in place.
        add_eot: Also keep the columns of the end-of-text tokens.
    """
    id_prompt = unet_controller.id_prompt
    prompt_tokens = prompt2tokens(unet_controller.tokenizer, unet_controller.prompts[0])

    words_tokens = prompt2tokens(unet_controller.tokenizer, id_prompt)
    words_tokens = [word for word in words_tokens if word != '<|endoftext|>' and word != '<|startoftext|>']
    index_of_words = find_sublist_index(prompt_tokens, words_tokens)
    if index_of_words == -1:
        # -1 means "not found"; using it as a position would build a wrong
        # keep-list, so leave the mask untouched instead.
        warnings.warn("apply_embeds_mask: id prompt tokens not found in the prompt; embeds mask not applied.")
        return

    index_list = [index + 77 for index in range(index_of_words, index_of_words + len(words_tokens))]
    if add_eot:
        index_list.extend([index + 77 for index, word in enumerate(prompt_tokens) if word == '<|endoftext|>'])

    mask_indices = torch.arange(dropout_mask.size(-1), device=dropout_mask.device)
    mask = (mask_indices >= 78) & (~torch.isin(mask_indices, torch.tensor(index_list, device=dropout_mask.device)))
    dropout_mask[0, :, mask] = 0


def gen_dropout_mask(out_shape, unet_controller: Optional[UNetController], drop_out):
    """Build one random 0/1 mask per batch entry for the concatenated keys.

    Every element is kept (1) with probability ``1 - drop_out``, except that
    the key segment belonging to the batch entry itself is always kept.

    Args:
        out_shape: Shape of the attention scores; indices 0, 2 and 3 are read
            as batch size, number of queries and number of keys.
        unet_controller: Provides ``torch_dtype`` and ``device`` for the mask.
        drop_out: Probability of dropping an element.

    Returns:
        Tensor of shape ``[out_shape[0], out_shape[2], out_shape[3]]``.
    """
    batch_num = out_shape[0]
    attn_map_side_length = out_shape[2]
    gen_length = out_shape[3]
    segment_length = gen_length // batch_num

    mask_list = []
    for prompt_index in range(batch_num):
        start = prompt_index * segment_length
        end = (prompt_index + 1) * segment_length

        mask = torch.bernoulli(torch.full((attn_map_side_length, gen_length), 1 - drop_out, dtype=unet_controller.torch_dtype, device=unet_controller.device))
        mask[:, start:end] = 1  # never drop the entry's own tokens

        mask_list.append(mask)

    return torch.stack(mask_list, dim=0)


def load_pipe_from_path(model_path, device, torch_dtype, variant):
    """Assemble the SDXL pipeline from the components stored under ``model_path``.

    Args:
        model_path: '/'-separated path (or id) of the model; its last component
            is used as the model name. A trailing '/' is ignored.
        device: Device the pipeline is moved to.
        torch_dtype: Passed to every ``from_pretrained`` call.
        variant: Passed to every ``from_pretrained`` call; forced to ``None``
            for the Juggernaut models for everything but the scheduler.

    Returns:
        Tuple ``(pipe, model_name)``.
    """
    model_name = model_path.rstrip('/').split('/')[-1]

    if model_name == 'playground-v2.5-1024px-aesthetic':
        scheduler = EDMDPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler", torch_dtype=torch_dtype, variant=variant,)
    else:
        scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler", torch_dtype=torch_dtype, variant=variant,)

    if model_name in ('Juggernaut-X-v10', 'Juggernaut-XI-v11'):
        variant = None

    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=torch_dtype, variant=variant,)
    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=torch_dtype, variant=variant,)
    tokenizer_2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=torch_dtype, variant=variant,)
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=torch_dtype, variant=variant,)
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=torch_dtype, variant=variant,)
    unet_new = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", torch_dtype=torch_dtype, variant=variant,)

    pipe = pipeline_stable_diffusion_xl.StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet_new,
        scheduler=scheduler,
    )
    pipe.to(device)

    return pipe, model_name


def get_max_window_length(unet_controller: Optional[UNetController], id_prompt, frame_prompt_list):
    """Count how many leading frame prompts fit after ``id_prompt`` within 77 words.

    The length is measured in whitespace-separated words, not tokenizer
    tokens. ``unet_controller`` is accepted but not used.

    Returns:
        Number of frame prompts that can be appended before the combined
        prompt reaches 77 words.
    """
    single_long_prompt = id_prompt
    max_window_length = 0
    for movement in frame_prompt_list:
        single_long_prompt += ' ' + movement
        token_length = len(single_long_prompt.split())
        if token_length >= 77:
            break
        max_window_length += 1
    return max_window_length


def movement_gen_story_slide_windows(id_prompt, frame_prompt_list, pipe, window_length, seed, unet_controller: Optional[UNetController], save_dir, verbose=True):
    """Generate one image per frame prompt, one pipeline call per frame.

    If the (clamped) window length is smaller than the number of frame prompts,
    each frame is generated from a circular window of prompts starting at that
    frame; otherwise every frame is generated from all frame prompts. In both
    cases the frame's own prompt is "expressed" and the others "suppressed"
    via ``unet_controller``.

    Args:
        id_prompt: Identity prompt prepended to every generation prompt.
        frame_prompt_list: One prompt per frame.
        pipe: Callable pipeline returning an object with ``.images``.
        window_length: Requested window size (clamped by ``get_max_window_length``).
        seed: Seed of the CPU ``torch.Generator``.
        unet_controller: Controller receiving the express/suppress prompts.
            With ``None`` each frame is generated from
            ``'{id_prompt} {frame prompt}'`` and the combined story image is
            not written to disk.
        save_dir: Directory the images are saved to.
        verbose: Print the seed and the per-frame prompts.

    Returns:
        Tuple ``(story_images, story_image)``: the list of frame images and
        their horizontal concatenation.
    """
    max_window_length = get_max_window_length(unet_controller, id_prompt, frame_prompt_list)
    window_length = min(window_length, max_window_length)
    use_sliding_windows = window_length < len(frame_prompt_list)
    if use_sliding_windows:
        movement_lists = circular_sliding_windows(frame_prompt_list, window_length)
    else:
        movement_lists = list(frame_prompt_list)
    story_images = []

    if verbose:
        print("seed:", seed)
    generate = torch.Generator().manual_seed(seed)
    if unet_controller is not None:
        unet_controller.id_prompt = id_prompt

    for index, movement in enumerate(frame_prompt_list):
        if unet_controller is not None:
            if use_sliding_windows:
                unet_controller.frame_prompt_suppress = movement_lists[index][1:]
                unet_controller.frame_prompt_express = movement_lists[index][0]
                gen_propmts = [f'{id_prompt} {" ".join(movement_lists[index])}']
            else:
                unet_controller.frame_prompt_suppress = movement_lists[:index] + movement_lists[index+1:]
                unet_controller.frame_prompt_express = movement_lists[index]
                gen_propmts = [f'{id_prompt} {" ".join(movement_lists)}']

            if verbose:
                print(f"suppress: {unet_controller.frame_prompt_suppress}")
                print(f"express: {unet_controller.frame_prompt_express}")
                print(f'id_prompt: {id_prompt}')
                print(f"gen_propmts: {gen_propmts}")
            frame_name = unet_controller.frame_prompt_express
        else:
            gen_propmts = f'{id_prompt} {movement}'
            frame_name = movement

        # Re-seed so that every frame starts from the same initial noise.
        if unet_controller is not None and unet_controller.Use_same_init_noise is True:
            generate = torch.Generator().manual_seed(seed)

        images = pipe(gen_propmts, generator=generate, unet_controller=unet_controller).images
        story_images.append(images[0])
        images[0].save(os.path.join(save_dir, f'{id_prompt} {frame_name}.jpg'))

    image_array_list = [np.array(pil_img) for pil_img in story_images]

    # Concatenate images horizontally
    story_image = np.concatenate(image_array_list, axis=1)
    story_image = Image.fromarray(story_image.astype(np.uint8))

    if unet_controller is not None and unet_controller.Save_story_image:
        story_image.save(os.path.join(save_dir, f'story_image_{id_prompt}.jpg'))

    return story_images, story_image


# Batched variant: generates several frames per pipeline call (batch > 1).
def movement_gen_story_slide_windows_batch(id_prompt, frame_prompt_list, pipe, window_length, seed, unet_controller: Optional[UNetController], save_dir, batch_size=3):
    """Generate one image per frame prompt, ``batch_size`` frames per pipeline call.

    The per-frame express/suppress prompts are computed as in
    ``movement_gen_story_slide_windows`` and handed to the pipeline through
    ``unet_controller.frame_prompt_express_list`` /
    ``frame_prompt_suppress_list``.

    Args:
        id_prompt: Identity prompt prepended to the generation prompt.
        frame_prompt_list: One prompt per frame.
        pipe: Callable pipeline returning an object with ``.images``.
        window_length: Requested window size (clamped by ``get_max_window_length``).
        seed: Seed of the CPU ``torch.Generator``.
        unet_controller: Controller receiving the express/suppress lists; it is
            dereferenced unconditionally and therefore required.
        save_dir: Directory the images are saved to.
        batch_size: Number of frames generated per pipeline call.

    Returns:
        Tuple ``(story_images, story_image)``: the list of frame images and
        their horizontal concatenation.
    """
    max_window_length = get_max_window_length(unet_controller, id_prompt, frame_prompt_list)
    window_length = min(window_length, max_window_length)
    use_sliding_windows = window_length < len(frame_prompt_list)
    if use_sliding_windows:
        movement_lists = circular_sliding_windows(frame_prompt_list, window_length)
    else:
        movement_lists = list(frame_prompt_list)

    print("seed:", seed)
    generate = torch.Generator().manual_seed(seed)
    unet_controller.id_prompt = id_prompt

    gen_prompt_info_list = []
    gen_prompt = None
    for index in range(len(frame_prompt_list)):
        if use_sliding_windows:
            frame_prompt_suppress = movement_lists[index][1:]
            frame_prompt_express = movement_lists[index][0]
            gen_prompt = f'{id_prompt} {" ".join(movement_lists[index])}'
        else:
            frame_prompt_suppress = movement_lists[:index] + movement_lists[index+1:]
            frame_prompt_express = movement_lists[index]
            gen_prompt = f'{id_prompt} {" ".join(movement_lists)}'

        gen_prompt_info_list.append({'frame_prompt_suppress': frame_prompt_suppress, 'frame_prompt_express': frame_prompt_express})

    # NOTE(review): only the gen_prompt of the LAST frame survives the loop and
    # is used for every frame below. That is harmless when all frames share one
    # prompt, but with sliding windows the last window may not even contain a
    # given frame's prompt — confirm whether per-frame prompts were intended.
    story_images = []
    for i in range(0, len(gen_prompt_info_list), batch_size):
        batch = gen_prompt_info_list[i:i + batch_size]
        gen_prompts = [gen_prompt for _ in batch]
        unet_controller.frame_prompt_express_list = [gen_prompt_info['frame_prompt_express'] for gen_prompt_info in batch]
        unet_controller.frame_prompt_suppress_list = [gen_prompt_info['frame_prompt_suppress'] for gen_prompt_info in batch]

        # Re-seed so that every batch starts from the same initial noise.
        if unet_controller.Use_same_init_noise is True:
            generate = torch.Generator().manual_seed(seed)

        images = pipe(gen_prompts, generator=generate, unet_controller=unet_controller).images
        for index, image in enumerate(images):
            story_images.append(image)
            image.save(os.path.join(save_dir, f'{id_prompt} {unet_controller.frame_prompt_express_list[index]}.jpg'))

    image_array_list = [np.array(pil_img) for pil_img in story_images]

    # Concatenate images horizontally
    story_image = np.concatenate(image_array_list, axis=1)
    story_image = Image.fromarray(story_image.astype(np.uint8))

    if unet_controller.Save_story_image:
        story_image.save(os.path.join(save_dir, 'story_image.jpg'))

    return story_images, story_image


def prompt2tokens(tokenizer, prompt):
    """Tokenize ``prompt`` (padded/truncated to ``tokenizer.model_max_length``) and return the token strings.

    Each id is mapped back to its string through ``tokenizer.decoder``.
    """
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return [tokenizer.decoder[text_input_id.item()] for text_input_id in text_input_ids[0]]


def punish_wight(tensor, latent_size, alpha=1.0, beta=1.2, calc_similarity=False):
    """Re-weight the singular values of ``tensor`` and rebuild it.

    Each singular value ``s`` becomes ``s * exp(-alpha * s) * beta``.

    Args:
        tensor: 2-D matrix to re-weight.
        latent_size: Number of leading left-singular vectors kept.
        alpha: Decay rate inside the exponential.
        beta: Constant gain applied to every singular value.
        calc_similarity: Also print the cosine distance between the first
            column before and after a variant of the re-weighting that
            additionally zeroes the singular values from index
            ``int(latent_size * alpha)`` on.

    Returns:
        The reconstructed matrix ``u @ diag(s') @ vh``.
    """
    u, s, vh = torch.linalg.svd(tensor)
    u = u[:, :latent_size]

    if calc_similarity:
        zero_idx = int(latent_size * alpha)
        _s = s.clone()
        _s *= torch.exp(-alpha*_s) * beta
        _s[zero_idx:] = 0
        _tensor = u @ torch.diag(_s) @ vh
        dist = cdist(tensor[:, 0].unsqueeze(0).cpu(), _tensor[:, 0].unsqueeze(0).cpu(), metric='cosine')
        print(f'The distance between the word embedding before and after the punishment: {dist}')

    s *= torch.exp(-alpha*s) * beta
    tensor = u @ torch.diag(s) @ vh
    return tensor


def swr_single_prompt_embeds(swr_words, prompt_embeds, prompt, tokenizer, alpha=1.0, beta=1.2, zero_eot=False):
    """Re-weight, in place, the embeddings of ``swr_words`` inside ``prompt_embeds``.

    The token positions of ``swr_words`` within ``prompt`` are located and the
    corresponding rows of ``prompt_embeds`` are replaced by the output of
    ``punish_wight``. Nothing is returned.

    Args:
        swr_words: Words whose embeddings are re-weighted.
        prompt_embeds: Embeddings indexed by token position along dim 0;
            modified in place.
        prompt: The full prompt the embeddings were computed from.
        tokenizer: Tokenizer used to locate the words.
        alpha: Forwarded to ``punish_wight``.
        beta: Forwarded to ``punish_wight``.
        zero_eot: If true, the end-of-text rows are only scaled by 0.9;
            otherwise they are re-weighted together with the words.
    """
    punish_indices = []

    prompt_tokens = prompt2tokens(tokenizer, prompt)

    words_tokens = prompt2tokens(tokenizer, swr_words)
    words_tokens = [word for word in words_tokens if word != '<|endoftext|>' and word != '<|startoftext|>']
    index_of_words = find_sublist_index(prompt_tokens, words_tokens)

    if index_of_words != -1:
        punish_indices.extend(range(index_of_words, index_of_words + len(words_tokens)))

    eot_indices = [index for index, word in enumerate(prompt_tokens) if word == '<|endoftext|>']
    if zero_eot:
        prompt_embeds[eot_indices] *= 9e-1
    else:
        punish_indices.extend(eot_indices)

    punish_indices = sorted(set(punish_indices))
    if not punish_indices:
        # Nothing to re-weight; avoid running an SVD on an empty matrix.
        return

    wo_batch = prompt_embeds[punish_indices]
    # The SVD is done in float64 and cast back to the embeddings' dtype.
    wo_batch = punish_wight(wo_batch.T.to(float), wo_batch.size(0), alpha=alpha, beta=beta, calc_similarity=False).T.to(prompt_embeds.dtype)

    prompt_embeds[punish_indices] = wo_batch


def find_sublist_index(list1, list2):
    """Return the first index at which ``list2`` occurs inside ``list1``, or -1 if it does not."""
    window = len(list2)
    for i in range(len(list1) - window + 1):
        if list1[i:i + window] == list2:
            return i
    return -1  # If sublist is not found


def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
    """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).

    The centred square of half-width ``threshold`` in the shifted 2-D spectrum
    is multiplied by ``scale``; all other frequencies are left unchanged.

    This version of the method comes from here:
    https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
    """
    x = x_in
    # The transform is computed in float32 and cast back at the end.
    x = x.to(dtype=torch.float32)

    # FFT
    x_freq = fftn(x, dim=(-2, -1))
    x_freq = fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = ifftshift(x_freq, dim=(-2, -1))
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(dtype=x_in.dtype)


def apply_freeu(
    resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """Applies the FreeU mechanism as introduced in https://arxiv.org/abs/2309.11497.
    Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.

    Only ``resolution_idx`` 0 and 1 are affected; for other values both tensors
    are returned unchanged. Note that ``hidden_states`` is scaled in place.

    Args:
        resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
        hidden_states (`torch.Tensor`): Inputs to the underlying block.
        res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
        s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
        s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
        b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
        b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
    """
    if resolution_idx == 0:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
    if resolution_idx == 1:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])

    return hidden_states, res_hidden_states


def circular_sliding_windows(lst, w):
    """Return, for every start index of ``lst``, the window of ``w`` elements wrapping around the end."""
    n = len(lst)
    return [[lst[(i + j) % n] for j in range(w)] for i in range(n)]