from transformers import PretrainedConfig
from PIL import Image
import torch
import numpy as np
import PIL
import os
from tqdm.auto import tqdm
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def myroll2d(a, delta_x, delta_y):
    """Shift a 2D array by (delta_x, delta_y) with zero padding.

    Unlike np.roll / torch.roll, pixels shifted past the border are dropped
    and the vacated region is filled with zeros.
    """
    h, w = a.shape[0], a.shape[1]
    # Internally compute source/target slices for a shift of (-delta_x, -delta_y).
    delta_x = -delta_x
    delta_y = -delta_y
    if isinstance(a, np.ndarray):
        b = np.zeros_like(a)
    elif isinstance(a, torch.Tensor):
        b = torch.zeros_like(a)
    else:
        raise TypeError(f"Unsupported input type: {type(a)}")
    # Horizontal bounds (x indexes columns, i.e. the width axis).
    if delta_x > 0:
        left_a, right_a = delta_x, w
        left_b, right_b = 0, w - delta_x
    else:
        left_a, right_a = 0, w + delta_x
        left_b, right_b = -delta_x, w
    # Vertical bounds (y indexes rows, i.e. the height axis).
    if delta_y > 0:
        top_a, bot_a = delta_y, h
        top_b, bot_b = 0, h - delta_y
    else:
        top_a, bot_a = 0, h + delta_y
        top_b, bot_b = -delta_y, h
    # Rows take the vertical bounds, columns the horizontal bounds.
    b[top_b:bot_b, left_b:right_b] = a[top_a:bot_a, left_a:right_a]
    return b

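# Illustrative usage (a minimal sketch, not executed here; the mask below is
# hypothetical and only shows the expected call signature):
#
#   mask = np.zeros((512, 512), dtype=np.uint8)
#   mask[100:200, 100:200] = 1
#   shifted = myroll2d(mask, delta_x=10, delta_y=5)   # zero-padded, no wrap-around
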
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision=None, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")

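# Illustrative usage (a sketch, assuming the standard SDXL checkpoint layout with
# "text_encoder" and "text_encoder_2" subfolders; the public SDXL base repo is
# used here only as an example model id):
#
#   model_id = "stabilityai/stable-diffusion-xl-base-1.0"
#   cls_one = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder")
#   cls_two = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder_2")
#   text_encoder_1 = cls_one.from_pretrained(model_id, subfolder="text_encoder")
#   text_encoder_2 = cls_two.from_pretrained(model_id, subfolder="text_encoder_2")
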
def image2latent(image, vae=None, dtype=None):
    """Encode a PIL image / HWC uint8 array into scaled VAE latents.

    A 4D tensor is assumed to already be a latent and is returned unchanged.
    """
    with torch.no_grad():
        # isinstance covers PIL.Image.Image and its format-specific subclasses
        # (PngImageFile, JpegImageFile, ...).
        if isinstance(image, Image.Image):
            image = np.array(image)
        if isinstance(image, torch.Tensor) and image.dim() == 4:
            latents = image
        else:
            # Map uint8 [0, 255] to [-1, 1] and add a batch dimension.
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
            latents = vae.encode(image).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
    return latents

def latent2image(latents, return_type="np", vae=None):
    """Decode VAE latents back to an image (uint8 HWC numpy array or tensor)."""
    # The VAE is always upcast to float32 for decoding here; the original
    # condition is kept for reference.
    # needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    needs_upcasting = True
    orig_dtype = vae.dtype
    if needs_upcasting:
        upcast_vae(vae)
        latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    if return_type == "np":
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
    if needs_upcasting:
        # Restore the VAE to its original dtype rather than forcing float16.
        vae.to(dtype=orig_dtype)
    return image

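# Illustrative round trip (a sketch, assuming an AutoencoderKL is already loaded
# as `vae` on `device`, and the image path is hypothetical):
#
#   init_image = Image.open("example.png").convert("RGB").resize((1024, 1024))
#   latents = image2latent(init_image, vae=vae, dtype=vae.dtype)   # [1, 4, 128, 128] for 1024px SDXL
#   recon = latent2image(latents, return_type="np", vae=vae)       # [1, 1024, 1024, 3] uint8
#   Image.fromarray(recon[0]).save("reconstruction.png")
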
def upcast_vae(vae):
    dtype = vae.dtype
    vae.to(dtype=torch.float32)
    use_torch_2_0_or_xformers = isinstance(
        vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        ),
    )
    # If xformers or torch 2.0 attention is used, the attention blocks do not
    # need to run in float32, which saves a lot of memory.
    if use_torch_2_0_or_xformers:
        vae.post_quant_conv.to(dtype)
        vae.decoder.conv_in.to(dtype)
        vae.decoder.mid_block.to(dtype)

def prompt_to_emb_length_sdxl(prompt, tokenizer, text_encoder, length=None):
    """Encode a prompt with one SDXL text encoder, padded to `length` tokens."""
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    prompt_embeds = text_encoder(text_input.input_ids.to(device), output_hidden_states=True)
    # The first output is the pooled embedding only for CLIPTextModelWithProjection
    # (the second SDXL encoder); the pipeline only uses the pooled output from it.
    pooled_prompt_embeds = prompt_embeds[0]
    # Per-token embeddings come from the penultimate hidden layer, as in SDXL.
    prompt_embeds = prompt_embeds.hidden_states[-2]
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}

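# Illustrative usage (a sketch, assuming `tokenizer` and `text_encoder_1` are the
# first SDXL tokenizer/text encoder already loaded on `device`):
#
#   out = prompt_to_emb_length_sdxl("a photo of a cat", tokenizer, text_encoder_1, length=20)
#   out["prompt_embeds"].shape   # [1, 20, 768] for the first SDXL encoder
#
# The pooled output is only meaningful when called with the second (projection)
# text encoder, where it has shape [1, 1280].
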
def prompt_to_emb_length_sd(prompt, tokenizer, text_encoder, length=None):
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    emb = text_encoder(text_input.input_ids.to(device))[0]
    return emb

def sdxl_prepare_input_decom(
    set_string_list,
    tokenizer,
    tokenizer_2,
    text_encoder_1,
    text_encoder_2,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    resolution=1024,
    normal_token_id_list=[],
):
    """Encode one prompt per segment with both SDXL text encoders.

    Prompts containing "#" or "$" (and whose index is not in `normal_token_id_list`)
    are padded to `length` tokens; all other prompts use the full 77-token context.
    """
    encoder_hidden_states_list = []
    pooled_prompt_embeds = 0
    for m_idx in range(len(set_string_list)):
        prompt_embeds_list = []
        # First text encoder.
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
            print(m_idx, set_string_list[m_idx])
        prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)
        # Second text encoder (also provides the pooled embedding).
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=77
            )
            print(m_idx, set_string_list[m_idx])
        prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
        pooled_prompt_embeds += out["pooled_prompt_embeds"].to(dtype=weight_dtype)
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)
        # Concatenate the two encoders' features along the channel dimension.
        encoder_hidden_states_list.append(torch.cat(prompt_embeds_list, dim=-1))
    # Average the pooled embeddings over all segments.
    add_text_embeds = pooled_prompt_embeds / len(set_string_list)
    target_size, original_size, crops_coords_top_left = (resolution, resolution), (resolution, resolution), (0, 0)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype, device=pooled_prompt_embeds.device)  # [B, 6]
    return encoder_hidden_states_list, add_text_embeds, add_time_ids

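# Illustrative usage (a sketch; the segment prompts and "#"/"$" placeholders are
# hypothetical, and tokenizer/tokenizer_2/text_encoder_1/text_encoder_2 are
# assumed to be the loaded SDXL pair):
#
#   set_string_list = ["photo of a # dog", "photo of a $ beach", "a photo of background"]
#   hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
#       set_string_list, tokenizer, tokenizer_2, text_encoder_1, text_encoder_2,
#       length=20, bsz=1, resolution=1024,
#   )
#   # hidden_states_list[i]: [1, 20, 2048] for placeholder prompts, [1, 77, 2048] otherwise
#   # add_text_embeds: [1, 1280];  add_time_ids: [1, 6]
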
def sd_prepare_input_decom(
    set_string_list,
    tokenizer,
    text_encoder_1,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    normal_token_id_list=[],
):
    """SD 1.x/2.x counterpart of sdxl_prepare_input_decom (single text encoder)."""
    encoder_hidden_states_list = []
    for m_idx in range(len(set_string_list)):
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
            print(m_idx, set_string_list[m_idx])
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        encoder_hidden_states_list.append(encoder_hidden_states.to(dtype=weight_dtype))
    return encoder_hidden_states_list

def load_mask(input_folder):
    """Load segment masks saved as maskN_<label>.npy files from `input_folder`."""
    np_mask_dtype = "uint8"
    mask_np_list = []
    mask_label_list = []
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name
        and "_" in file_name and "Edited" not in file_name
    ]
    # Sort by the integer index that follows the "mask" prefix.
    files = sorted(files, key=lambda x: int(x.split("_")[0][4:]))
    for file_name in files:
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        mask_label_list.append(mask_label)
    mask_list = [torch.from_numpy(mask_np) for mask_np in mask_np_list]
    # The masks must partition the image: every pixel belongs to exactly one mask.
    try:
        assert torch.all(sum(mask_list) == 1)
    except AssertionError:
        print("please check mask")
        import pdb; pdb.set_trace()
    return mask_list, mask_label_list

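# Illustrative layout (a sketch of the naming convention the loader expects; the
# folder name and labels are hypothetical):
#
#   example_folder/
#       mask0_background.npy   # uint8 array, 1 inside the segment, 0 elsewhere
#       mask1_dog.npy
#       mask2_sky.npy
#
#   mask_list, mask_label_list = load_mask("example_folder")
#   # mask_label_list == ["background", "dog", "sky"]; the masks sum to 1 at every pixel
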
def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    """Load an image (path or array), crop the given margins, center-crop to square, and resize."""
    if isinstance(image_path, str):
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    # Clamp the crop margins so the remaining region is non-empty.
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    # Center-crop to a square before resizing.
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((size, size)))
    return image

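# Illustrative usage (a sketch; the input path is hypothetical):
#
#   img = load_image("example_input.jpg", size=1024)    # (1024, 1024, 3) uint8 array
#   latents = image2latent(img, vae=vae, dtype=vae.dtype)
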
def mask_union_torch(*masks):
    """Return a boolean mask that is True wherever any input mask is non-zero."""
    masks = [m.to(torch.float) for m in masks]
    res = sum(masks) > 0
    return res

def load_mask_edit(input_folder):
    """Load edited masks saved as maskEditedN_<label>.npy files (created by the UI)."""
    np_mask_dtype = "uint8"
    mask_np_list = []
    mask_label_list = []
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name and "_" in file_name
        and "Edited" in file_name and "-1" not in file_name
    ]
    # Sort by the integer index that follows the "maskEdited" prefix.
    files = sorted(files, key=lambda x: int(x.split("_")[0][10:]))
    for file_name in files:
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        mask_label_list.append(mask_label)
    mask_list = [torch.from_numpy(mask_np) for mask_np in mask_np_list]
    # The edited masks must still partition the image.
    try:
        assert torch.all(sum(mask_list) == 1)
    except AssertionError:
        print("Make sure maskEdited is in the folder; if not, generate it using the UI")
        import pdb; pdb.set_trace()
    return mask_list, mask_label_list

def save_images(images, filename, num_rows=1, offset_ratio=0.02):
    """Save each image in `images` as <filename stem>_<i>.<ext> in the same folder."""
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0
    # Blank white frames used to pad the list when num_rows > 1.
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    folder = os.path.dirname(filename)
    stem, ext = os.path.basename(filename).rsplit(".", 1)
    for i, image in enumerate(images):
        pil_img = Image.fromarray(image)
        out_name = "{}_{}.{}".format(stem, i, ext)
        pil_img.save(os.path.join(folder, out_name))
        print("saved to", os.path.join(folder, out_name))