"""Zero123++-style multi-view diffusion pipeline built on top of diffusers.

The module wraps a ``UNet2DConditionModel`` so that its attention layers can
record and re-read hidden states of a noised condition latent (reference-only
attention), optionally adds a depth ControlNet, and exposes the whole thing
as a ``StableDiffusionPipeline`` subclass.
"""
from typing import Any, Dict, Optional
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers

import numpy
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.distributed
import transformers
from collections import OrderedDict
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
    UNet2DConditionModel,
    ImagePipelineOutput
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
from diffusers.utils.import_utils import is_xformers_available

# AutoencoderKL, UNet2DConditionModel and DDPMScheduler are the core building
# blocks of the generative model used below: latent encoding, the denoising
# network, and noise scheduling respectively.


def to_rgb_image(maybe_rgba: Image.Image) -> Image.Image:
    """Return ``maybe_rgba`` as an RGB image.

    RGB images are returned unchanged (same object). RGBA images are
    composited onto an opaque white background, using the alpha channel as
    the paste mask. The rest of this module only works with three-channel
    images, so transparency is flattened here.

    Args:
        maybe_rgba: A PIL image in mode ``'RGB'`` or ``'RGBA'``.

    Returns:
        An RGB PIL image with the same size as the input.

    Raises:
        ValueError: If the image mode is neither ``'RGB'`` nor ``'RGBA'``.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        # The background is a constant white canvas (every channel 255);
        # transparent regions of the input therefore become white.
        background = Image.new('RGB', maybe_rgba.size, (255, 255, 255))
        background.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return background
    raise ValueError("Unsupported image type.", maybe_rgba.mode)


class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention processor that records or re-reads reference hidden states.

    The processor wraps another processor (``chained_proc``) and, when
    ``enabled``, manipulates ``encoder_hidden_states`` according to ``mode``:

    * ``'w'`` (write): store the current ``encoder_hidden_states`` in
      ``ref_dict`` under this processor's name.
    * ``'r'`` (read): concatenate the stored entry (removing it from
      ``ref_dict``) to ``encoder_hidden_states`` along ``dim=1``.
    * ``'m'`` (merge): like ``'r'`` but the stored entry is kept.

    When disabled it simply forwards to ``chained_proc``.
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        """Store the wrapped processor and this processor's identity.

        Args:
            chained_proc: The processor that performs the actual attention.
            enabled: Whether the reference read/write logic is active.
            name: Key used for this processor's entry in ``ref_dict``.
        """
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance = False
    ) -> Any:
        """Run attention, optionally writing/reading reference states.

        Args:
            attn: The attention module being processed.
            hidden_states: Query-side hidden states.
            encoder_hidden_states: Key/value-side hidden states; defaults to
                ``hidden_states`` when ``None``.
            attention_mask: Forwarded unchanged to ``chained_proc``.
            mode: ``'w'``, ``'r'`` or ``'m'`` (see class docstring). Only
                consulted when the processor is enabled.
            ref_dict: Dictionary shared between the write and read passes.
            is_cfg_guidance: When true (and enabled), the first batch element
                is processed separately without reference states and its
                result is concatenated in front of the rest.

        Returns:
            Whatever ``chained_proc`` returns (concatenated along the batch
            dimension in the ``is_cfg_guidance`` case).

        Raises:
            ValueError: If the processor is enabled and ``mode`` is unknown.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        split_first = self.enabled and is_cfg_guidance
        if split_first:
            # Element 0 bypasses the reference mechanism entirely.
            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]
        if self.enabled:
            if mode == 'w':
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
            elif mode == 'm':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
            else:
                # An explicit raise (rather than ``assert``) so the check
                # survives ``python -O``.
                raise ValueError(f"Unknown reference attention mode: {mode!r}")
        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        if split_first:
            res = torch.cat([res0, res])
        return res


class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that conditions on a noised reference latent.

    Every attention processor of the wrapped UNet is replaced by a
    :class:`ReferenceOnlyAttnProc`; only processors whose name ends with
    ``"attn1.processor"`` are enabled. Each forward call first runs the UNet
    on a noised copy of the condition latent in write mode, then runs it on
    the actual sample in read mode.
    """

    def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
        """Wrap ``unet`` and install the reference attention processors.

        Args:
            unet: The UNet to wrap; its attention processors are replaced.
            train_sched: Scheduler used to noise the condition latent when
                the module is in training mode.
            val_sched: Scheduler used to noise the condition latent otherwise.
        """
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Pick the default processor by feature detection rather than by
        # comparing version strings (string comparison mis-orders e.g.
        # '10.0' and '2.0'). ``scaled_dot_product_attention`` is what
        # AttnProcessor2_0 needs.
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            default_proc_cls = AttnProcessor2_0
        elif is_xformers_available():
            default_proc_cls = XFormersAttnProcessor
        else:
            default_proc_cls = AttnProcessor

        ref_attn_procs = {
            name: ReferenceOnlyAttnProc(
                default_proc_cls(), enabled=name.endswith("attn1.processor"), name=name
            )
            for name in unet.attn_processors
        }
        unet.set_attn_processor(ref_attn_procs)

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for attributes not found here."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        """Run the UNet on the noised condition latent in write mode.

        The output of the UNet is discarded; the call exists for its side
        effect of filling ``ref_dict`` through the attention processors.

        Args:
            noisy_cond_lat: Noised condition latent.
            timestep: Timestep passed to the UNet.
            encoder_hidden_states: Conditioning states; the first batch
                element is dropped when ``is_cfg_guidance`` is true.
            class_labels: Optional class labels; sliced like
                ``encoder_hidden_states`` when present.
            ref_dict: Dictionary the attention processors write into.
            is_cfg_guidance: Whether the first batch element should be
                excluded from the reference pass.
            **kwargs: Forwarded to the UNet.
        """
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            # ``class_labels`` defaults to None in ``forward``; slicing None
            # would raise a TypeError.
            if class_labels is not None:
                class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        """Denoise ``sample`` while attending to the noised condition latent.

        Args:
            sample: Latent to denoise.
            timestep: Current timestep.
            encoder_hidden_states: Conditioning states.
            class_labels: Optional class labels forwarded to the UNet.
            *args: Extra positional arguments forwarded to the UNet.
            cross_attention_kwargs: Must contain ``'cond_lat'``; may contain
                ``'is_cfg_guidance'`` (defaults to ``False``).
            down_block_res_samples: Optional additional residuals for the
                down blocks, cast to the UNet dtype before use.
            mid_block_res_sample: Optional additional residual for the mid
                block, cast to the UNet dtype before use.
            **kwargs: Forwarded to both UNet calls.

        Returns:
            The output of the wrapped UNet's read-mode call.
        """
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            flat_timestep = timestep.reshape(-1)
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, flat_timestep)
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, flat_timestep)

        # First pass: populate ref_dict with the reference hidden states.
        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        weight_dtype = self.unet.dtype
        if down_block_res_samples is not None:
            down_residuals = [
                residual.to(dtype=weight_dtype) for residual in down_block_res_samples
            ]
        else:
            down_residuals = None
        if mid_block_res_sample is not None:
            mid_residual = mid_block_res_sample.to(dtype=weight_dtype)
        else:
            mid_residual = None

        # Second pass: the actual denoising step, reading from ref_dict.
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            **kwargs
        )


def scale_latents(latents):
    """Map latents to the scaled space: ``(latents - 0.22) * 0.75``."""
    latents = (latents - 0.22) * 0.75
    return latents


def unscale_latents(latents):
    """Inverse of :func:`scale_latents`: ``latents / 0.75 + 0.22``."""
    latents = latents / 0.75 + 0.22
    return latents


def scale_image(image):
    """Scale image values by ``0.5 / 0.8``."""
    image = image * 0.5 / 0.8
    return image


def unscale_image(image):
    """Inverse of :func:`scale_image`: ``image / 0.5 * 0.8``."""
    image = image / 0.5 * 0.8
    return image


class DepthControlUNet(torch.nn.Module):
    """Combine a :class:`RefOnlyNoisedUNet` with a depth ControlNet."""

    def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
        """Attach a ControlNet to ``unet``.

        Args:
            unet: The reference-only UNet wrapper to drive.
            controlnet: ControlNet to use; when ``None`` a new one is created
                from the inner UNet via ``ControlNetModel.from_unet``.
            conditioning_scale: Scale passed to the ControlNet on each call.
        """
        super().__init__()
        self.unet = unet
        if controlnet is None:
            self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
        else:
            self.controlnet = controlnet
        # xformers is preferred here when it is installed.
        DefaultAttnProc = AttnProcessor2_0
        if is_xformers_available():
            DefaultAttnProc = XFormersAttnProcessor
        self.controlnet.set_attn_processor(DefaultAttnProc())
        self.conditioning_scale = conditioning_scale

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for attributes not found here."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
        """Run the ControlNet, then the UNet with the ControlNet residuals.

        Args:
            sample: Latent to denoise.
            timestep: Current timestep.
            encoder_hidden_states: Conditioning states.
            class_labels: Accepted for signature compatibility.
            *args: Accepted for signature compatibility.
            cross_attention_kwargs: Must contain ``'control_depth'``, which is
                removed from a copy (the caller's dict is not mutated) and
                used as the ControlNet condition; the rest is passed on.
            **kwargs: Accepted for signature compatibility.

        Returns:
            The output of the wrapped UNet.
        """
        # NOTE(review): class_labels, *args and **kwargs are not forwarded to
        # either network - confirm this is intentional before relying on them.
        cross_attention_kwargs = dict(cross_attention_kwargs)
        control_depth = cross_attention_kwargs.pop('control_depth')
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=control_depth,
            conditioning_scale=self.conditioning_scale,
            return_dict=False,
        )
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            cross_attention_kwargs=cross_attention_kwargs
        )


class ModuleListDict(torch.nn.Module):
    """Read-only, dict-like container of modules with sorted keys."""

    def __init__(self, procs: dict) -> None:
        """Store ``procs`` as a ``ModuleList`` ordered by sorted key.

        Args:
            procs: Mapping from key to module.
        """
        super().__init__()
        self.keys = sorted(procs.keys())
        self.values = torch.nn.ModuleList(procs[k] for k in self.keys)

    def __getitem__(self, key):
        """Return the module stored under ``key``.

        Raises:
            ValueError: If ``key`` is unknown (raised by ``list.index``).
        """
        return self.values[self.keys.index(key)]


class SuperNet(torch.nn.Module):
    """Hold several named modules in one ``ModuleList``.

    State-dict hooks rename the ``layers.<index>`` prefixes to the original
    names on save and back again on load.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        """Build the container from a name-to-module mapping.

        Args:
            state_dict: Mapping from name to module (the values are placed
                in a ``ModuleList``). Entries are ordered by sorted name.
        """
        super().__init__()
        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            # Return the prefix of ``key`` that identifies its module.
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k
            return key.split('.')[0]

        def map_from(module, state_dict, *args, **kwargs):
            # Snapshot the keys first: the dict is mutated while renaming.
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)


class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    """Image-conditioned diffusion pipeline with reference-only attention.

    The condition image is encoded twice: by the VAE (as the reference
    latent handed to the UNet wrapper) and by the CLIP vision encoder (as a
    global embedding added to the prompt embeddings). An optional depth
    ControlNet can be attached with :meth:`add_controlnet`.
    """

    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection

    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers

    vae: AutoencoderKL
    ramping: nn.Linear

    feature_extractor_vae: transformers.CLIPImageProcessor

    # Converts a depth image to a tensor and normalizes with mean 0.5, std 0.5.
    depth_transforms_multi = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        """Register the pipeline components.

        Args:
            vae: Autoencoder used to encode the condition image and decode
                the result.
            text_encoder: CLIP text encoder.
            tokenizer: CLIP tokenizer.
            unet: Denoising UNet (wrapped lazily by :meth:`prepare`).
            scheduler: Sampling scheduler.
            vision_encoder: CLIP vision encoder for the global embedding.
            feature_extractor_clip: Preprocessor for the vision encoder.
            feature_extractor_vae: Preprocessor for the VAE input.
            ramping_coefficients: Weights applied to the global image
                embedding before it is added to the prompt embeddings;
                stored in the pipeline config.
            safety_checker: Accepted but ignored; ``None`` is always
                registered.
        """
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler, safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        """Wrap the raw UNet in :class:`RefOnlyNoisedUNet` if not done yet.

        A ``DDPMScheduler`` is built from the current scheduler's config to
        serve as the training scheduler; the pipeline's own scheduler is
        used as the validation scheduler. The wrapper is put in eval mode.
        """
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
        """Attach a depth ControlNet to the pipeline's UNet.

        Args:
            controlnet: ControlNet to attach; when ``None`` one is created
                from the UNet.
            conditioning_scale: Scale applied to the ControlNet condition.

        Returns:
            A :class:`SuperNet` containing the ControlNet under the name
            ``'controlnet'``.
        """
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        """Encode ``image`` with the VAE and sample from the latent distribution."""
        image = self.vae.encode(image).latent_dist.sample()
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt = "",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs
    ):
        """Generate an image conditioned on ``image``.

        Args:
            image: Condition image (PIL, mode RGB or RGBA). Required.
            prompt: Text prompt whose embeddings are combined with the
                global image embedding.
            *args: Extra positional arguments for the parent pipeline call.
            num_images_per_prompt: Number of images to generate per prompt.
            guidance_scale: Guidance scale; when greater than 1 a latent of
                an all-zero image is prepended to the condition latent.
            depth_image: Depth image for the ControlNet; required when a
                ControlNet is attached and ignored otherwise.
            output_type: Output format passed to the image processor;
                ``"latent"`` skips VAE decoding.
            width: Output width passed to the parent pipeline.
            height: Output height passed to the parent pipeline.
            num_inference_steps: Number of denoising steps.
            return_dict: Whether to return an ``ImagePipelineOutput`` or a
                one-element tuple.
            **kwargs: Extra keyword arguments for the parent pipeline call.

        Returns:
            ``ImagePipelineOutput`` when ``return_dict`` is true, otherwise
            ``(images,)``.

        Raises:
            ValueError: If ``image`` is ``None``, if an image has an
                unsupported mode, or if a ControlNet is attached but no
                ``depth_image`` is given.
            TypeError: If ``image`` is a ``torch.Tensor``.
        """
        self.prepare()
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if isinstance(image, torch.Tensor):
            raise TypeError("Tensor inputs are not supported for this pipeline. Please pass a PIL image.")
        image = to_rgb_image(image)
        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values

        has_controlnet = hasattr(self.unet, "controlnet")
        if has_controlnet:
            if depth_image is None:
                # Without this check the ControlNet would receive ``None``
                # as its condition and fail with an obscure error.
                raise ValueError("A ControlNet is attached to this pipeline. Please pass a depth_image.")
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )

        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])

        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)

        # Newer diffusers versions expose ``encode_prompt`` (returning a
        # tuple); older ones only have ``_encode_prompt``.
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )[0]
        else:
            encoder_hidden_states = self._encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp

        cak = dict(cond_lat=cond_lat)
        if has_controlnet:
            cak['control_depth'] = depth_image

        latents: torch.Tensor = super().__call__(
            None,
            *args,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            **kwargs
        ).images
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)