Commit · 82ef366
Parent(s): 991d8d3
initial commit
Browse files
- appearance_transfer_model.py +177 -0
- config.py +66 -0
- constants.py +3 -0
- demo.py +96 -0
- environment/environment.yaml +10 -0
- environment/requirements.txt +17 -0
- inputs/chocolate_cake.jpg +0 -0
- inputs/duomo.png +0 -0
- inputs/giraffe.png +0 -0
- inputs/red_velvet_cake.jpg +0 -0
- inputs/taj_mahal.jpg +0 -0
- inputs/zebra.png +0 -0
- models/__init__.py +0 -0
- models/stable_diffusion.py +240 -0
- models/unet_2d_condition.py +345 -0
- utils/__init__.py +0 -0
- utils/adain.py +45 -0
- utils/attention_utils.py +37 -0
- utils/ddpm_inversion.py +323 -0
- utils/image_utils.py +59 -0
- utils/latent_utils.py +81 -0
- utils/model_utils.py +16 -0
- utils/segmentation.py +111 -0
appearance_transfer_model.py
ADDED
@@ -0,0 +1,177 @@
from typing import List, Optional, Callable

import torch
import torch.nn.functional as F

from config import RunConfig
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from utils import attention_utils
from utils.adain import masked_adain
from utils.model_utils import get_stable_diffusion_model
from utils.segmentation import Segmentor


class AppearanceTransferModel:

    def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
        self.config = config
        self.pipe = get_stable_diffusion_model() if pipe is None else pipe
        self.register_attention_control()
        self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
        self.latents_app, self.latents_struct = None, None
        self.zs_app, self.zs_struct = None, None
        self.image_app_mask_32, self.image_app_mask_64 = None, None
        self.image_struct_mask_32, self.image_struct_mask_64 = None, None
        self.enable_edit = False
        self.step = 0

    def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
        self.latents_app = latents_app
        self.latents_struct = latents_struct

    def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
        self.zs_app = zs_app
        self.zs_struct = zs_struct

    def set_masks(self, masks: List[torch.Tensor]):
        self.image_app_mask_32, self.image_struct_mask_32, self.image_app_mask_64, self.image_struct_mask_64 = masks

    def get_adain_callback(self):

        def callback(st: int, timestep: int, latents: torch.FloatTensor) -> Callable:
            self.step = st
            # Compute the masks using prompt-mixing self-segmentation and use the masks for the AdaIN operation
            if self.step == self.config.adain_range.start:
                masks = self.segmentor.get_object_masks()
                self.set_masks(masks)
            # Apply the AdaIN operation using the computed masks
            if self.config.adain_range.start <= self.step < self.config.adain_range.end:
                latents[0] = masked_adain(latents[0], latents[1], self.image_struct_mask_64, self.image_app_mask_64)

        return callback

    def register_attention_control(self):

        model_self = self

        class AttentionProcessor:

            def __init__(self, place_in_unet: str):
                self.place_in_unet = place_in_unet
                if not hasattr(F, "scaled_dot_product_attention"):
                    raise ImportError("AttnProcessor2_0 requires torch 2.0. To use it, please upgrade torch to 2.0.")

            def __call__(self,
                         attn,
                         hidden_states: torch.Tensor,
                         encoder_hidden_states: Optional[torch.Tensor] = None,
                         attention_mask=None,
                         temb=None,
                         perform_swap: bool = False):

                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)

                input_ndim = hidden_states.ndim

                if input_ndim == 4:
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )

                if attention_mask is not None:
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                is_cross = encoder_hidden_states is not None
                if not is_cross:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads
                should_mix = False

                # Potentially apply our cross-image attention operation.
                # To do so, we need to be in a self-attention layer in the decoder part of the denoising network.
                if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
                    if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
                        should_mix = True
                        if model_self.step % 5 == 0 and model_self.step < 40:
                            # Inject the structure's keys and values
                            key[OUT_INDEX] = key[STRUCT_INDEX]
                            value[OUT_INDEX] = value[STRUCT_INDEX]
                        else:
                            # Inject the appearance's keys and values
                            key[OUT_INDEX] = key[STYLE_INDEX]
                            value[OUT_INDEX] = value[STYLE_INDEX]

                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # Compute the cross attention and apply our contrasting operation
                hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
                    query, key, value,
                    edit_map=perform_swap and model_self.enable_edit and should_mix,
                    is_cross=is_cross,
                    contrast_strength=model_self.config.contrast_strength,
                )

                # Update the attention map used for segmentation
                if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
                    model_self.segmentor.update_attention(attn_weight, is_cross)

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                hidden_states = hidden_states.to(query[OUT_INDEX].dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if attn.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / attn.rescale_output_factor

                return hidden_states

        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'ResnetBlock2D':
                pass
            if net_.__class__.__name__ == 'Attention':
                net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        sub_nets = self.pipe.unet.named_children()
        for net in sub_nets:
            if "down" in net[0]:
                cross_att_count += register_recr(net[1], 0, "down")
            elif "up" in net[0]:
                cross_att_count += register_recr(net[1], 0, "up")
            elif "mid" in net[0]:
                cross_att_count += register_recr(net[1], 0, "mid")
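To summarize what the processor above does during a swap step: inside the decoder self-attention layers, the output image's queries attend to keys and values taken from the appearance image (or, on every fifth step before step 40, from the structure image) instead of its own. In our notation (batch indices follow constants.py, d is the per-head dimension), the swapped attention for the output branch is roughly:

\mathrm{Attn}^{\text{out}} = \mathrm{softmax}\!\left(\frac{Q^{\text{out}}\,(K^{\text{app}})^{\top}}{\sqrt{d}}\right) V^{\text{app}}

with the softmax map additionally passed through the contrasting operation from utils/attention_utils.py before it is applied to the values.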
config.py
ADDED
@@ -0,0 +1,66 @@
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple, Optional


class Range(NamedTuple):
    start: int
    end: int


@dataclass
class RunConfig:
    # Appearance image path
    app_image_path: Path
    # Struct image path
    struct_image_path: Path
    # Domain name (e.g., buildings, animals)
    domain_name: Optional[str] = None
    # Output path
    output_path: Path = Path('./output')
    # Random seed
    seed: int = 42
    # Input prompt for inversion (will use the domain name as default)
    prompt: Optional[str] = None
    # Number of timesteps
    num_timesteps: int = 100
    # Whether to use a binary mask for performing AdaIN
    use_masked_adain: bool = True
    # Timesteps to apply cross-attention on 64x64 layers
    cross_attn_64_range: Range = Range(start=10, end=90)
    # Timesteps to apply cross-attention on 32x32 layers
    cross_attn_32_range: Range = Range(start=10, end=70)
    # Timesteps to apply AdaIN
    adain_range: Range = Range(start=20, end=100)
    # Guidance scale
    guidance_scale: float = 7.5
    # Swap guidance scale
    swap_guidance_scale: float = 3.5
    # Attention contrasting strength
    contrast_strength: float = 1.67
    # Object nouns to use for self-segmentation (will use the domain name as default)
    object_noun: Optional[str] = None
    # Whether to load previously saved inverted latent codes
    load_latents: bool = True
    # Number of steps to skip in the denoising process (value used in the original edit-friendly DDPM paper)
    skip_steps: int = 32

    def __post_init__(self):
        self.output_path = self.output_path / self.domain_name
        self.output_path.mkdir(parents=True, exist_ok=True)

        # Handle the domain name, prompt, and object nouns used for masking, etc.
        if self.use_masked_adain and self.domain_name is None:
            raise ValueError("Must provide --domain_name and --prompt when using masked AdaIN")
        if not self.use_masked_adain and self.domain_name is None:
            self.domain_name = "object"
        if self.prompt is None:
            self.prompt = f"A photo of a {self.domain_name}"
        if self.object_noun is None:
            self.object_noun = self.domain_name

        # Define the paths to store the inverted latents to
        self.latents_path = Path(self.output_path) / "latents"
        self.latents_path.mkdir(parents=True, exist_ok=True)
        self.app_latent_save_path = self.latents_path / f"{self.app_image_path.stem}.pt"
        self.struct_latent_save_path = self.latents_path / f"{self.struct_image_path.stem}.pt"
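As a quick illustration (a minimal sketch, not part of the commit itself), the dataclass can be instantiated directly with two of the sample images added under inputs/; __post_init__ then fills in the default prompt and the latent save paths:

# Sketch only: constructs a RunConfig the same way demo.py does, using files from this commit.
from pathlib import Path

from config import RunConfig

cfg = RunConfig(
    app_image_path=Path("inputs/zebra.png"),       # appearance source
    struct_image_path=Path("inputs/giraffe.png"),  # structure source
    domain_name="animal",
    load_latents=False,
)
# cfg.prompt is now f"A photo of a {cfg.domain_name}" and cfg.latents_path points to ./output/animal/latents.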
constants.py
ADDED
@@ -0,0 +1,3 @@
OUT_INDEX = 0
STYLE_INDEX = 1
STRUCT_INDEX = 2
demo.py
ADDED
@@ -0,0 +1,96 @@
import sys
from pathlib import Path
from typing import Optional

import gradio as gr
from PIL import Image

from appearance_transfer_model import AppearanceTransferModel
from run import run_appearance_transfer
from utils.latent_utils import load_latents_or_invert_images
from utils.model_utils import get_stable_diffusion_model

sys.path.append(".")
sys.path.append("..")

from config import RunConfig

DESCRIPTION = '''
<h1 style="text-align: center;"> Cross-Image Attention for Zero-Shot Appearance Transfer </h1>
<p style="text-align: center;">
This is a demo for our <a href="https://arxiv.org/abs/2311.03335">paper</a>:
''Cross-Image Attention for Zero-Shot Appearance Transfer''.
<br>
Given two images depicting a source structure and a target appearance, our method generates an image merging
the structure of one image with the appearance of the other.
<br>
We do so in a zero-shot manner, with no optimization or model training required, while supporting appearance
transfer across images that may differ in size and shape.
</p>
'''

pipe = get_stable_diffusion_model()


def main_pipeline(app_image_path: str,
                  struct_image_path: str,
                  domain_name: str,
                  seed: int,
                  prompt: Optional[str] = None) -> Image.Image:
    if prompt == "":
        prompt = None
    config = RunConfig(
        app_image_path=Path(app_image_path),
        struct_image_path=Path(struct_image_path),
        domain_name=domain_name,
        prompt=prompt,
        seed=seed,
        load_latents=False
    )
    model = AppearanceTransferModel(config=config, pipe=pipe)
    latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=config)
    model.set_latents(latents_app, latents_struct)
    model.set_noise(noise_app, noise_struct)
    print("Running appearance transfer...")
    images = run_appearance_transfer(model=model, cfg=config)
    print("Done.")
    return [images[0]]


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    gr.HTML('''<a href="https://huggingface.co/spaces/yuvalalaluf/cross-image-attention?duplicate=true"><img src="https://bit.ly/3gLdBN6"
            alt="Duplicate Space"></a>''')

    with gr.Row():
        with gr.Column():
            app_image_path = gr.Image(label="Upload appearance image", type="filepath")
            struct_image_path = gr.Image(label="Upload structure image", type="filepath")
            domain_name = gr.Text(label="Domain name", max_lines=1,
                                  info="Specifies the domain the objects are coming from (e.g., 'animal', 'building', etc).")
            prompt = gr.Text(label="Prompt to use for inversion.", value='',
                             info='If this is kept empty, we will use the domain name to define '
                                  'the prompt as "A photo of a <domain_name>".')
            random_seed = gr.Number(value=42, label="Random seed", precision=0)
            run_button = gr.Button('Generate')

        with gr.Column():
            result = gr.Gallery(label='Result')
            inputs = [app_image_path, struct_image_path, domain_name, random_seed, prompt]
            outputs = [result]
            run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

    with gr.Row():
        examples = [
            ['inputs/zebra.png', 'inputs/giraffe.png', 'animal', 20, None],
            ['inputs/taj_mahal.jpg', 'inputs/duomo.png', 'building', 42, None],
            ['inputs/red_velvet_cake.jpg', 'inputs/chocolate_cake.jpg', 'cake', 42, 'A photo of cake'],
        ]
        gr.Examples(examples=examples,
                    inputs=[app_image_path, struct_image_path, domain_name, random_seed, prompt],
                    outputs=[result],
                    fn=main_pipeline,
                    cache_examples=True)

demo.launch(share=False, server_name="127.0.0.1", server_port=8888)
environment/environment.yaml
ADDED
@@ -0,0 +1,10 @@
name: cross_image
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.5
  - pip=20.3
  - cudatoolkit=11.3
  - pip:
    - -r requirements.txt
environment/requirements.txt
ADDED
@@ -0,0 +1,17 @@
matplotlib==3.6.3
matplotlib-inline==0.1.6
jupyter==1.0.0
numpy==1.24.1
pyrallis==0.3.1
torch==2.0.1
torchvision==0.15.2
diffusers==0.19.3
transformers==4.30.2
accelerate==0.20.3
huggingface-hub==0.16.4
xformers==0.0.21
tokenizers==0.13.3
nltk==3.8.1
Pillow==10.1.0
scikit_learn==1.3.0
tqdm==4.64.1
inputs/chocolate_cake.jpg
ADDED
inputs/duomo.png
ADDED
inputs/giraffe.png
ADDED
inputs/red_velvet_cake.jpg
ADDED
inputs/taj_mahal.jpg
ADDED
inputs/zebra.png
ADDED
models/__init__.py
ADDED
File without changes
models/stable_diffusion.py
ADDED
@@ -0,0 +1,240 @@
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.schedulers import KarrasDiffusionSchedulers
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

from config import Range
from models.unet_2d_condition import FreeUUNet2DConditionModel


class CrossImageAttentionStableDiffusionPipeline(StableDiffusionPipeline):
    """ A modification of the standard StableDiffusionPipeline to incorporate our cross-image attention."""

    def __init__(self, vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: FreeUUNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True):
        super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )

    @torch.no_grad()
    def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
            callback_steps: int = 1,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guidance_rescale: float = 0.0,
            swap_guidance_scale: float = 1.0,
            cross_image_attention_range: Range = Range(10, 90),
            # DDPM addition
            zs: Optional[List[torch.Tensor]] = None
    ):

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise an error if they are not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs[0].shape[0]:])}
        timesteps = timesteps[-zs[0].shape[0]:]

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        op = tqdm(timesteps[-zs[0].shape[0]:])
        n_timesteps = len(timesteps[-zs[0].shape[0]:])

        count = 0
        for t in op:
            i = t_to_idx[int(t)]

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred_swap = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs={'perform_swap': True},
                return_dict=False,
            )[0]
            noise_pred_no_swap = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs={'perform_swap': False},
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                _, noise_swap_pred_text = noise_pred_swap.chunk(2)
                noise_no_swap_pred_uncond, _ = noise_pred_no_swap.chunk(2)
                noise_pred = noise_no_swap_pred_uncond + guidance_scale * (
                        noise_swap_pred_text - noise_no_swap_pred_uncond)
            else:
                is_cross_image_step = cross_image_attention_range.start <= i <= cross_image_attention_range.end
                if swap_guidance_scale > 1.0 and is_cross_image_step:
                    swapping_strengths = np.linspace(swap_guidance_scale,
                                                     max(swap_guidance_scale / 2, 1.0),
                                                     n_timesteps)
                    swapping_strength = swapping_strengths[count]
                    noise_pred = noise_pred_no_swap + swapping_strength * (noise_pred_swap - noise_pred_no_swap)
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_swap, guidance_rescale=guidance_rescale)
                else:
                    noise_pred = noise_pred_swap

            latents = torch.stack([
                self.perform_ddpm_step(t_to_idx, zs[latent_idx], latents[latent_idx], t, noise_pred[latent_idx], eta)
                for latent_idx in range(latents.shape[0])
            ])

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                # progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

            count += 1

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def perform_ddpm_step(self, t_to_idx, zs, latents, t, noise_pred, eta):
        idx = t_to_idx[int(t)]
        z = zs[idx] if zs is not None else None
        # 1. get previous step value (=t-1)
        prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        # 2. compute alphas, betas
        alpha_prod_t = self.scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        # 3. compute the predicted original sample from the predicted noise, also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
        # 5. compute variance: "sigma_t(eta)" -> see formula (16)
        # sigma_t = sqrt((1 - alpha_{t-1}) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_{t-1})
        # variance = self.scheduler._get_variance(timestep, prev_timestep)
        variance = self.get_variance(t)
        std_dev_t = eta * variance ** (0.5)
        # Take care of the asymmetric reverse process (asyrp)
        model_output_direction = noise_pred
        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
        pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
        # 8. Add noise if eta > 0
        if eta > 0:
            if z is None:
                z = torch.randn(noise_pred.shape, device=self.device)
            sigma_z = eta * variance ** (0.5) * z
            prev_sample = prev_sample + sigma_z
        return prev_sample

    def get_variance(self, timestep):
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        return variance
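perform_ddpm_step follows formula (12) of the DDIM paper that its comments cite, with the random term taken from the precomputed edit-friendly noise maps zs rather than sampled fresh. Taking eta = 1 for concreteness (our transcription, with alpha_prod_t written as \alpha_t), the update implemented above is:

x_{t-1} = \sqrt{\alpha_{t-1}}\;\underbrace{\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\alpha_t}}}_{\text{pred\_original\_sample}}
\;+\; \sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t) \;+\; \sigma_t z_t,
\qquad
\sigma_t^2 = \frac{1-\alpha_{t-1}}{1-\alpha_t}\left(1-\frac{\alpha_t}{\alpha_{t-1}}\right)

where z_t is the stored noise map for this timestep and get_variance computes \sigma_t^2.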
models/unet_2d_condition.py
ADDED
@@ -0,0 +1,345 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import logging
from torch.fft import fftn, ifftn, fftshift, ifftshift

"""
This is a small extension of the standard UNet2DConditionModel with the small addition of the
Free-U trick (https://github.com/ChenyangSi/FreeU).
"""

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def Fourier_filter(x, threshold, scale):
    # FFT
    x_freq = fftn(x, dim=(-2, -1))
    x_freq = fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W)).cuda()  # the mask is allocated on CUDA

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = ifftshift(x_freq, dim=(-2, -1))
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered


class FreeUUNet2DConditionModel(UNet2DConditionModel):

    def forward(
            self,
            sample: torch.FloatTensor,
            timestep: Union[torch.Tensor, float, int],
            encoder_hidden_states: torch.Tensor,
            class_labels: Optional[torch.Tensor] = None,
            timestep_cond: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
            down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            mid_block_additional_residual: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
            return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. The mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2 ** self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                    down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # Add the Free-U trick here!
            # Fourier Filter
            if sample.shape[1] == 1280:
                sample[:, :640] *= 1.2  # 1.1 # For SD2.1
                sample = Fourier_filter(sample, threshold=1, scale=0.9)

            if sample.shape[1] == 640:
                sample[:, :320] *= 1.4  # 1.2 # For SD2.1
                sample = Fourier_filter(sample, threshold=1, scale=0.2)

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
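For readers unfamiliar with the FreeU trick referenced in the module docstring: in the two largest decoder stages, the first half of the channels is amplified (the 1.2 and 1.4 factors above) and the features are then passed through Fourier_filter, which scales only the low-frequency band of the shifted 2D spectrum. In our notation, with threshold tau and scale s as in the code:

x' = \mathcal{F}^{-1}\!\Big(\mathrm{ifftshift}\big(m \odot \mathrm{fftshift}(\mathcal{F}(x))\big)\Big),
\qquad
m(u, v) = \begin{cases} s, & \text{in the } 2\tau \times 2\tau \text{ block centred on the zero frequency} \\ 1, & \text{otherwise} \end{cases}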
utils/__init__.py
ADDED
File without changes
utils/adain.py
ADDED
@@ -0,0 +1,45 @@
def masked_adain(content_feat, style_feat, content_mask, style_mask):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat, mask=style_mask)
    content_mean, content_std = calc_mean_std(content_feat, mask=content_mask)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    style_normalized_feat = normalized_feat * style_std.expand(size) + style_mean.expand(size)
    return content_feat * (1 - content_mask) + style_normalized_feat * content_mask


def calc_mean_std(feat, eps=1e-5, mask=None):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    if len(size) == 2:
        return calc_mean_std_2d(feat, eps, mask)

    assert (len(size) == 3)
    C = size[0]
    if mask is not None:
        feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
        feat_std = feat_var.sqrt().view(C, 1, 1)
        feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1, 1)
    else:
        feat_var = feat.view(C, -1).var(dim=1) + eps
        feat_std = feat_var.sqrt().view(C, 1, 1)
        feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1, 1)

    return feat_mean, feat_std


def calc_mean_std_2d(feat, eps=1e-5, mask=None):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 2)
    C = size[0]
    if mask is not None:
        feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
        feat_std = feat_var.sqrt().view(C, 1)
        feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1)
    else:
        feat_var = feat.view(C, -1).var(dim=1) + eps
        feat_std = feat_var.sqrt().view(C, 1)
        feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1)

    return feat_mean, feat_std
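masked_adain above is the standard AdaIN operation restricted by binary masks: the per-channel statistics are computed only over pixels where each mask equals 1, and the re-normalized features are written back only inside the content mask. In the usual notation:

\mathrm{AdaIN}(x, y) = \sigma_M(y)\,\frac{x - \mu_M(x)}{\sigma_M(x)} + \mu_M(y),
\qquad
\mathrm{out} = (1 - m_x)\odot x + m_x \odot \mathrm{AdaIN}(x, y)

where \mu_M and \sigma_M denote the masked per-channel mean and standard deviation and m_x is the content mask.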
utils/attention_utils.py
ADDED
@@ -0,0 +1,37 @@
import math
import torch

from constants import OUT_INDEX


def should_mix_keys_and_values(model, hidden_states: torch.Tensor) -> bool:
    """ Verify whether we should perform the mixing in the current timestep. """
    is_in_32_timestep_range = (
        model.config.cross_attn_32_range.start <= model.step < model.config.cross_attn_32_range.end
    )
    is_in_64_timestep_range = (
        model.config.cross_attn_64_range.start <= model.step < model.config.cross_attn_64_range.end
    )
    is_hidden_states_32_square = (hidden_states.shape[1] == 32 ** 2)
    is_hidden_states_64_square = (hidden_states.shape[1] == 64 ** 2)
    should_mix = (is_in_32_timestep_range and is_hidden_states_32_square) or \
                 (is_in_64_timestep_range and is_hidden_states_64_square)
    return should_mix


def compute_scaled_dot_product_attention(Q, K, V, edit_map=False, is_cross=False, contrast_strength=1.0):
    """ Compute the scaled dot product attention, potentially with our contrasting operation. """
    attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))), dim=-1)
    if edit_map and not is_cross:
        attn_weight[OUT_INDEX] = torch.stack([
            torch.clip(enhance_tensor(attn_weight[OUT_INDEX][head_idx], contrast_factor=contrast_strength),
                       min=0.0, max=1.0)
            for head_idx in range(attn_weight.shape[1])
        ])
    return attn_weight @ V, attn_weight


def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
    """ Compute the attention map contrasting. """
    adjusted_tensor = (tensor - tensor.mean(dim=-1)) * contrast_factor + tensor.mean(dim=-1)
    return adjusted_tensor
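A quick sanity-check sketch for the attention helper above; the tensor shapes (3 images, 8 heads, 16 tokens, head dimension 40) are illustrative assumptions:

import torch

from utils.attention_utils import compute_scaled_dot_product_attention

# Q, K, V laid out as (batch, heads, tokens, head_dim); the contrast branch only runs
# when edit_map=True and is_cross=False, so this call is plain softmax attention.
Q = torch.randn(3, 8, 16, 40)
K = torch.randn(3, 8, 16, 40)
V = torch.randn(3, 8, 16, 40)

out, attn_weight = compute_scaled_dot_product_attention(Q, K, V, edit_map=False)
print(out.shape, attn_weight.shape)  # torch.Size([3, 8, 16, 40]) torch.Size([3, 8, 16, 16])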
utils/ddpm_inversion.py
ADDED
@@ -0,0 +1,323 @@
import abc

import torch
from torch import inference_mode
from tqdm import tqdm

"""
Inversion code taken from:
1. The official implementation of Edit-Friendly DDPM Inversion: https://github.com/inbarhub/DDPM_inversion
2. The LEDITS demo: https://huggingface.co/spaces/editing-images/ledits/tree/main
"""

LOW_RESOURCE = True


def invert(x0, pipe, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
    # inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
    # based on the code in https://github.com/inbarhub/DDPM_inversion
    # returns wt, zs, wts:
    # wt - inverted latent
    # wts - intermediate inverted latents
    # zs - noise maps
    pipe.scheduler.set_timesteps(num_diffusion_steps)
    with inference_mode():
        w0 = (pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
    wt, zs, wts = inversion_forward_process(pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src,
                                            prog_bar=True, num_inference_steps=num_diffusion_steps)
    return zs, wts


def inversion_forward_process(model, x0,
                              etas=None,
                              prog_bar=False,
                              prompt="",
                              cfg_scale=3.5,
                              num_inference_steps=50, eps=None
                              ):
    if not prompt == "":
        text_embeddings = encode_text(model, prompt)
    uncond_embedding = encode_text(model, "")
    timesteps = model.scheduler.timesteps.to(model.device)
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)
    if etas is None or (type(etas) in [int, float] and etas == 0):
        eta_is_zero = True
        zs = None
    else:
        eta_is_zero = False
        if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
        xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
        alpha_bar = model.scheduler.alphas_cumprod
        zs = torch.zeros(size=variance_noise_shape, device=model.device)

    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xt = x0
    op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)

    for t in op:
        idx = t_to_idx[int(t)]
        # 1. predict noise residual
        if not eta_is_zero:
            xt = xts[idx][None]

        with torch.no_grad():
            out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
            if not prompt == "":
                cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)

        if not prompt == "":
            ## classifier free guidance
            noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
        else:
            noise_pred = out.sample

        if eta_is_zero:
            # 2. compute more noisy image and set x_t -> x_t+1
            xt = forward_step(model, noise_pred, t, xt)

        else:
            xtm1 = xts[idx + 1][None]
            # pred of x0
            pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5

            # direction to xt
            prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
            alpha_prod_t_prev = model.scheduler.alphas_cumprod[
                prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod

            variance = get_variance(model, t)
            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred

            mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
            zs[idx] = z

            # correction to avoid error accumulation
            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
            xts[idx + 1] = xtm1

    if not zs is None:
        zs[-1] = torch.zeros_like(zs[-1])

    return xt, zs, xts


def encode_text(model, prompts):
    text_input = model.tokenizer(
        prompts,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
    return text_encoding


def sample_xts_from_x0(model, x0, num_inference_steps=50):
    """
    Samples from P(x_1:T|x_0)
    """
    # torch.manual_seed(43256465436)
    alpha_bar = model.scheduler.alphas_cumprod
    sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
    alphas = model.scheduler.alphas
    betas = 1 - alphas
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)

    timesteps = model.scheduler.timesteps.to(model.device)
    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xts = torch.zeros(variance_noise_shape).to(x0.device)
    for t in reversed(timesteps):
        idx = t_to_idx[int(t)]
        xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
    xts = torch.cat([xts, x0], dim=0)

    return xts


def forward_step(model, model_output, timestep, sample):
    next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
                        timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)

    # 2. compute alphas, betas
    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    next_sample = model.scheduler.add_noise(pred_original_sample,
                                            model_output,
                                            torch.LongTensor([next_timestep]))
    return next_sample


def get_variance(model, timestep):
    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = model.scheduler.alphas_cumprod[
        prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    return variance


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return self.num_att_layers if LOW_RESOURCE else 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}


def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, context=None, mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            return to_out(out)

        return forward

    class DummyController:

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count
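A rough usage sketch of the inversion entry point, mirroring the call made in utils/latent_utils.py further below; the image path is a placeholder and a CUDA device is assumed:

import numpy as np
import torch
from PIL import Image

from utils.ddpm_inversion import invert
from utils.model_utils import get_stable_diffusion_model

pipe = get_stable_diffusion_model()

# Placeholder input: any 512x512 RGB image, scaled to [-1, 1] and shaped (1, 3, 512, 512).
image = np.array(Image.open("example.png").convert("RGB").resize((512, 512)))
x0 = torch.from_numpy(image).float() / 127.5 - 1.0
x0 = x0.permute(2, 0, 1).unsqueeze(0).to("cuda")

# zs holds the per-step noise maps, wts the intermediate inverted latents.
zs, wts = invert(x0=x0, pipe=pipe, prompt_src="a photo", num_diffusion_steps=100)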
utils/image_utils.py
ADDED
@@ -0,0 +1,59 @@
import pathlib
from typing import Optional, Tuple

import numpy as np
from PIL import Image

from config import RunConfig


def load_images(cfg: RunConfig, save_path: Optional[pathlib.Path] = None) -> Tuple[Image.Image, Image.Image]:
    image_style = load_size(cfg.app_image_path)
    image_struct = load_size(cfg.struct_image_path)
    if save_path is not None:
        Image.fromarray(image_style).save(save_path / f"in_style.png")
        Image.fromarray(image_struct).save(save_path / f"in_struct.png")
    return image_style, image_struct


def load_size(image_path: pathlib.Path,
              left: int = 0,
              right: int = 0,
              top: int = 0,
              bottom: int = 0,
              size: int = 512) -> Image.Image:
    if type(image_path) is str or type(image_path) is pathlib.PosixPath:
        image = np.array(Image.open(image_path).convert('RGB'))
    else:
        image = image_path

    h, w, c = image.shape

    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - left - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    h, w, c = image.shape

    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]

    image = np.array(Image.fromarray(image).resize((size, size)))
    return image


def save_generated_masks(model, cfg: RunConfig):
    tensor2im(model.image_app_mask_32).save(cfg.output_path / f"mask_style_32.png")
    tensor2im(model.image_struct_mask_32).save(cfg.output_path / f"mask_struct_32.png")
    tensor2im(model.image_app_mask_64).save(cfg.output_path / f"mask_style_64.png")
    tensor2im(model.image_struct_mask_64).save(cfg.output_path / f"mask_struct_64.png")


def tensor2im(x) -> Image.Image:
    return Image.fromarray(x.cpu().numpy().astype(np.uint8) * 255)
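A short sketch of the loader above; the input path is a placeholder:

from utils.image_utils import load_size

# Center-crops to a square and resizes to 512x512, returning an RGB NumPy array.
image = load_size("example.jpg")
print(image.shape)  # (512, 512, 3)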
utils/latent_utils.py
ADDED
@@ -0,0 +1,81 @@
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
from PIL import Image

from appearance_transfer_model import AppearanceTransferModel
from config import RunConfig
from utils import image_utils
from utils.ddpm_inversion import invert


def load_latents_or_invert_images(model: AppearanceTransferModel, cfg: RunConfig):
    if cfg.load_latents and cfg.app_latent_save_path.exists() and cfg.struct_latent_save_path.exists():
        print("Loading existing latents...")
        latents_app, latents_struct = load_latents(cfg.app_latent_save_path, cfg.struct_latent_save_path)
        noise_app, noise_struct = load_noise(cfg.app_latent_save_path, cfg.struct_latent_save_path)
        print("Done.")
    else:
        print("Inverting images...")
        app_image, struct_image = image_utils.load_images(cfg=cfg, save_path=cfg.output_path)
        model.enable_edit = False  # Deactivate the cross-image attention layers
        latents_app, latents_struct, noise_app, noise_struct = invert_images(app_image=app_image,
                                                                             struct_image=struct_image,
                                                                             sd_model=model.pipe,
                                                                             cfg=cfg)
        model.enable_edit = True
        print("Done.")
    return latents_app, latents_struct, noise_app, noise_struct


def load_latents(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
    latents_app = torch.load(app_latent_save_path)
    latents_struct = torch.load(struct_latent_save_path)
    if type(latents_struct) == list:
        latents_app = [l.to("cuda") for l in latents_app]
        latents_struct = [l.to("cuda") for l in latents_struct]
    else:
        latents_app = latents_app.to("cuda")
        latents_struct = latents_struct.to("cuda")
    return latents_app, latents_struct


def load_noise(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
    latents_app = torch.load(app_latent_save_path.parent / (app_latent_save_path.stem + "_ddpm_noise.pt"))
    latents_struct = torch.load(struct_latent_save_path.parent / (struct_latent_save_path.stem + "_ddpm_noise.pt"))
    latents_app = latents_app.to("cuda")
    latents_struct = latents_struct.to("cuda")
    return latents_app, latents_struct


def invert_images(sd_model: AppearanceTransferModel, app_image: Image.Image, struct_image: Image.Image, cfg: RunConfig):
    input_app = torch.from_numpy(np.array(app_image)).float() / 127.5 - 1.0
    input_struct = torch.from_numpy(np.array(struct_image)).float() / 127.5 - 1.0
    zs_app, latents_app = invert(x0=input_app.permute(2, 0, 1).unsqueeze(0).to('cuda'),
                                 pipe=sd_model,
                                 prompt_src=cfg.prompt,
                                 num_diffusion_steps=cfg.num_timesteps,
                                 cfg_scale_src=3.5)
    zs_struct, latents_struct = invert(x0=input_struct.permute(2, 0, 1).unsqueeze(0).to('cuda'),
                                       pipe=sd_model,
                                       prompt_src=cfg.prompt,
                                       num_diffusion_steps=cfg.num_timesteps,
                                       cfg_scale_src=3.5)
    # Save the inverted latents and noises
    torch.save(latents_app, cfg.latents_path / f"{cfg.app_image_path.stem}.pt")
    torch.save(latents_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}.pt")
    torch.save(zs_app, cfg.latents_path / f"{cfg.app_image_path.stem}_ddpm_noise.pt")
    torch.save(zs_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}_ddpm_noise.pt")
    return latents_app, latents_struct, zs_app, zs_struct


def get_init_latents_and_noises(model: AppearanceTransferModel, cfg: RunConfig) -> Tuple[torch.Tensor, torch.Tensor]:
    # If we stored all the latents along the diffusion process, select the desired one based on the skip_steps
    if model.latents_struct.dim() == 4 and model.latents_app.dim() == 4 and model.latents_app.shape[0] > 1:
        model.latents_struct = model.latents_struct[cfg.skip_steps]
        model.latents_app = model.latents_app[cfg.skip_steps]
    init_latents = torch.stack([model.latents_struct, model.latents_app, model.latents_struct])
    init_zs = [model.zs_struct[cfg.skip_steps:], model.zs_app[cfg.skip_steps:], model.zs_struct[cfg.skip_steps:]]
    return init_latents, init_zs
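A rough end-to-end sketch of how these helpers are driven; it assumes RunConfig accepts app_image_path / struct_image_path as constructor arguments (with the remaining fields left at their defaults), which is an assumption rather than something shown in this commit, and the paths are placeholders:

from pathlib import Path

from appearance_transfer_model import AppearanceTransferModel
from config import RunConfig
from utils.latent_utils import load_latents_or_invert_images, get_init_latents_and_noises

# Placeholder inputs; assumed RunConfig constructor arguments.
cfg = RunConfig(app_image_path=Path("appearance.png"),
                struct_image_path=Path("structure.png"))
model = AppearanceTransferModel(cfg)

# Invert (or load cached) latents and noise maps, then attach them to the model.
latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=cfg)
model.set_latents(latents_app, latents_struct)
model.set_noise(noise_app, noise_struct)

# Starting latents are stacked as [struct, app, struct] with their matching noise schedules.
init_latents, init_zs = get_init_latents_and_noises(model=model, cfg=cfg)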
utils/model_utils.py
ADDED
@@ -0,0 +1,16 @@
import torch
from diffusers import DDIMScheduler

from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from models.unet_2d_condition import FreeUUNet2DConditionModel


def get_stable_diffusion_model() -> CrossImageAttentionStableDiffusionPipeline:
    print("Loading Stable Diffusion model...")
    device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')
    pipe = CrossImageAttentionStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                                      safety_checker=None).to(device)
    pipe.unet = FreeUUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet").to(device)
    pipe.scheduler = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
    print("Done.")
    return pipe
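Since loading the pipeline is the slow step, a reasonable pattern (a sketch, assuming an already-built RunConfig named cfg) is to load it once and hand it to the transfer model:

from appearance_transfer_model import AppearanceTransferModel
from utils.model_utils import get_stable_diffusion_model

pipe = get_stable_diffusion_model()               # loads SD 1.5 with the FreeU UNet and a DDIM scheduler
model = AppearanceTransferModel(cfg, pipe=pipe)   # reuses the loaded pipeline; `cfg` is an existing RunConfig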
utils/segmentation.py
ADDED
@@ -0,0 +1,111 @@
from typing import Tuple, List

import nltk
import numpy as np
import torch
from sklearn.cluster import KMeans

from constants import STYLE_INDEX, STRUCT_INDEX

nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

"""
Self-segmentation technique taken from Prompt Mixing: https://github.com/orpatashnik/local-prompt-mixing
"""

class Segmentor:

    def __init__(self, prompt: str, object_nouns: List[str], num_segments: int = 5, res: int = 32):
        self.prompt = prompt
        self.num_segments = num_segments
        self.resolution = res
        self.object_nouns = object_nouns
        tokenized_prompt = nltk.word_tokenize(prompt)
        forbidden_words = [word.upper() for word in ["photo", "image", "picture"]]
        self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt))
                      if pos[:2] == 'NN' and word.upper() not in forbidden_words]

    def update_attention(self, attn, is_cross):
        res = int(attn.shape[2] ** 0.5)
        if is_cross:
            if res == 16:
                self.cross_attention_32 = attn
            elif res == 32:
                self.cross_attention_64 = attn
        else:
            if res == 32:
                self.self_attention_32 = attn
            elif res == 64:
                self.self_attention_64 = attn

    def __call__(self, *args, **kwargs):
        clusters = self.cluster()
        cluster2noun = self.cluster2noun(clusters)
        return cluster2noun

    def cluster(self, res: int = 32):
        np.random.seed(1)
        self_attn = self.self_attention_32 if res == 32 else self.self_attention_64

        style_attn = self_attn[STYLE_INDEX].mean(dim=0).cpu().numpy()
        style_kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(style_attn)
        style_clusters = style_kmeans.labels_.reshape(res, res)

        struct_attn = self_attn[STRUCT_INDEX].mean(dim=0).cpu().numpy()
        struct_kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(struct_attn)
        struct_clusters = struct_kmeans.labels_.reshape(res, res)

        return style_clusters, struct_clusters

    def cluster2noun(self, clusters, cross_attn, attn_index):
        result = {}
        res = int(cross_attn.shape[2] ** 0.5)
        nouns_indices = [index for (index, word) in self.nouns]
        cross_attn = cross_attn[attn_index].mean(dim=0).reshape(res, res, -1)
        nouns_maps = cross_attn.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
        normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
        for i in range(nouns_maps.shape[-1]):
            curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
            normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()

        max_score = 0
        all_scores = []
        for c in range(self.num_segments):
            cluster_mask = np.zeros_like(clusters)
            cluster_mask[clusters == c] = 1
            score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
            scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
            all_scores.append(max(scores))
            max_score = max(max(scores), max_score)

        all_scores.remove(max_score)
        mean_score = sum(all_scores) / len(all_scores)

        for c in range(self.num_segments):
            cluster_mask = np.zeros_like(clusters)
            cluster_mask[clusters == c] = 1
            score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
            scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
            result[c] = self.nouns[np.argmax(np.array(scores))] if max(scores) > 1.4 * mean_score else "BG"

        return result

    def create_mask(self, clusters, cross_attention, attn_index):
        cluster2noun = self.cluster2noun(clusters, cross_attention, attn_index)
        mask = clusters.copy()
        obj_segments = [c for c in cluster2noun if cluster2noun[c][1] in self.object_nouns]
        for c in range(self.num_segments):
            mask[clusters == c] = 1 if c in obj_segments else 0
        return torch.from_numpy(mask).to("cuda")

    def get_object_masks(self) -> Tuple[torch.Tensor]:
        clusters_style_32, clusters_struct_32 = self.cluster(res=32)
        clusters_style_64, clusters_struct_64 = self.cluster(res=64)

        mask_style_32 = self.create_mask(clusters_style_32, self.cross_attention_32, STYLE_INDEX)
        mask_struct_32 = self.create_mask(clusters_struct_32, self.cross_attention_32, STRUCT_INDEX)
        mask_style_64 = self.create_mask(clusters_style_64, self.cross_attention_64, STYLE_INDEX)
        mask_struct_64 = self.create_mask(clusters_struct_64, self.cross_attention_64, STRUCT_INDEX)

        return mask_style_32, mask_struct_32, mask_style_64, mask_struct_64
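A small sketch of how the Segmentor is driven; the prompt and object noun are placeholders, and the attention maps are expected to be filled in by the pipeline during denoising:

from utils.segmentation import Segmentor

segmentor = Segmentor(prompt="a photo of an animal", object_nouns=["animal"])
print(segmentor.nouns)  # typically [(4, 'animal')] -- "photo" is filtered out as a forbidden word

# During denoising, the pipeline feeds self- and cross-attention maps in via:
#   segmentor.update_attention(attn, is_cross=...)
# Once the 32x32 and 64x64 maps are populated, the object masks can be read with:
#   mask_style_32, mask_struct_32, mask_style_64, mask_struct_64 = segmentor.get_object_masks()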