Spaces:
Runtime error

modules #1
by MoinulwithAI - opened
- .gitattributes +0 -10
- __init__.cpython-310.pyc +0 -0
- __init__.py +0 -0
- app.py +7 -5
- attention.cpython-310.pyc +0 -0
- attention.py +0 -133
- autoencoder.cpython-310.pyc +0 -0
- autoencoder.py +0 -326
- conditioner.cpython-310.pyc +0 -0
- conditioner.py +0 -216
- connector_edit.cpython-310.pyc +0 -0
- connector_edit.py +0 -486
- cookie.png → examples 2.zip +2 -2
- examples 2/celeb_meme.jpg +0 -3
- examples 2/cookie.png +0 -3
- examples 2/ghibli_meme.jpg +0 -0
- examples 2/leather.jpg +0 -3
- examples 2/meme.jpg +0 -0
- examples 2/no_cookie.png +0 -3
- examples 2/poster.jpg +0 -0
- examples 2/poster_orig.jpg +0 -3
- ghibli_meme.jpg +0 -0
- layers.cpython-310.pyc +0 -0
- layers.py +0 -640
- leather.jpg +0 -3
- meme.jpg +0 -0
- model_edit.cpython-310.pyc +0 -0
- model_edit.py +0 -143
- celeb_meme.jpg → modules.zip +2 -2
- modules/__init__.py +0 -0
- modules/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/__pycache__/attention.cpython-310.pyc +0 -0
- modules/__pycache__/autoencoder.cpython-310.pyc +0 -0
- modules/__pycache__/conditioner.cpython-310.pyc +0 -0
- modules/__pycache__/connector_edit.cpython-310.pyc +0 -0
- modules/__pycache__/layers.cpython-310.pyc +0 -0
- modules/__pycache__/model_edit.cpython-310.pyc +0 -0
- modules/attention.py +0 -133
- modules/autoencoder.py +0 -326
- modules/conditioner.py +0 -216
- modules/connector_edit.py +0 -486
- modules/layers.py +0 -640
- modules/model_edit.py +0 -143
- no_cookie.png +0 -3
- poster.jpg +0 -0
- poster_orig.jpg +0 -3
.gitattributes
CHANGED
@@ -33,13 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-celeb_meme.jpg filter=lfs diff=lfs merge=lfs -text
-cookie.png filter=lfs diff=lfs merge=lfs -text
-leather.jpg filter=lfs diff=lfs merge=lfs -text
-no_cookie.png filter=lfs diff=lfs merge=lfs -text
-poster_orig.jpg filter=lfs diff=lfs merge=lfs -text
-examples[[:space:]]2/celeb_meme.jpg filter=lfs diff=lfs merge=lfs -text
-examples[[:space:]]2/cookie.png filter=lfs diff=lfs merge=lfs -text
-examples[[:space:]]2/leather.jpg filter=lfs diff=lfs merge=lfs -text
-examples[[:space:]]2/no_cookie.png filter=lfs diff=lfs merge=lfs -text
-examples[[:space:]]2/poster_orig.jpg filter=lfs diff=lfs merge=lfs -text
__init__.cpython-310.pyc
DELETED
Binary file (128 Bytes)

__init__.py
DELETED
File without changes
app.py
CHANGED
@@ -8,6 +8,7 @@ import spaces
 import time
 from pathlib import Path
 
+
 import gradio as gr
 import numpy as np
 import torch
@@ -60,7 +61,7 @@ def load_models(
     dit_path=None,
     ae_path=None,
     qwen2vl_model_path=None,
-    device="
+    device="cuda",
     max_length=256,
     dtype=torch.bfloat16,
 ):
@@ -117,7 +118,7 @@ class ImageGenerator:
     dit_path=None,
     ae_path=None,
     qwen2vl_model_path=None,
-    device="
+    device="cuda",
     max_length=640,
     dtype=torch.bfloat16,
 ) -> None:
@@ -134,9 +135,9 @@ class ImageGenerator:
         self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
 
     def to_cuda(self):
-        self.ae.to(device='
-        self.dit.to(device='
-        self.llm_encoder.to(device='
+        self.ae.to(device='cuda', dtype=torch.float32)
+        self.dit.to(device='cuda', dtype=torch.bfloat16)
+        self.llm_encoder.to(device='cuda', dtype=torch.bfloat16)
 
     def prepare(self, prompt, img, ref_image, ref_image_raw):
         bs, _, h, w = img.shape
@@ -487,4 +488,5 @@ with gr.Blocks() as demo:
     fn=generate_examples,
     cache_examples=True
 )
+
 demo.launch()
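For reference, a minimal sketch of how the updated defaults in app.py might be exercised; the checkpoint paths and model id below are illustrative assumptions, not values taken from this Space.

# Hypothetical usage sketch for the edited app.py (paths and model id are placeholders).
import torch
from app import ImageGenerator  # assumes app.py is importable

generator = ImageGenerator(
    dit_path="checkpoints/dit.safetensors",             # placeholder path
    ae_path="checkpoints/ae.safetensors",               # placeholder path
    qwen2vl_model_path="Qwen/Qwen2.5-VL-7B-Instruct",   # placeholder model id
    device="cuda",        # new default added by this change
    max_length=640,
    dtype=torch.bfloat16,
)
generator.to_cuda()  # now moves ae (fp32) plus dit and llm_encoder (bf16) onto the GPU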
attention.cpython-310.pyc
DELETED
Binary file (3.13 kB)
attention.py
DELETED
@@ -1,133 +0,0 @@
import math

import torch
import torch.nn.functional as F


try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_func,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    flash_attn_func = None

MEMORY_LAYOUT = {
    # flash mode:
    # pre-process: input [batch_size, seq_len, num_heads, head_dim]
    # post-process: keep the shape unchanged
    "flash": (
        lambda x: x,  # keep shape
        lambda x: x,  # keep shape
    ),
    # torch/vanilla mode:
    # pre-process: swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
    # post-process: swap back to the original order [B,A,S,D] -> [B,S,A,D]
    "torch": (
        lambda x: x.transpose(1, 2),  # (B,S,A,D) -> (B,A,S,D)
        lambda x: x.transpose(1, 2),  # (B,A,S,D) -> (B,S,A,D)
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
):
    """
    Perform QKV self-attention.

    Args:
        q (torch.Tensor): query tensor of shape [batch_size, seq_len, num_heads, head_dim]
        k (torch.Tensor): key tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
        v (torch.Tensor): value tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
        mode (str): attention mode, one of 'flash', 'torch', 'vanilla'
        drop_rate (float): dropout probability applied to the attention map
        attn_mask (torch.Tensor): attention mask, whose shape depends on the mode
        causal (bool): whether to use causal attention (attend only to earlier positions)

    Returns:
        torch.Tensor: attention output of shape [batch_size, seq_len, num_heads * head_dim]
    """
    # Fetch the pre- and post-processing functions
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    # Apply the pre-processing transform
    q = pre_attn_layout(q)  # shape depends on the mode
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        # Use PyTorch's native scaled_dot_product_attention
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_func is not None, "flash_attn_func is not defined"
        assert attn_mask is None, "attention mask is not supported"
        x: torch.Tensor = flash_attn_func(
            q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
        )  # type: ignore
    elif mode == "vanilla":
        # Manual attention implementation
        scale_factor = 1 / math.sqrt(q.size(-1))  # scaling factor 1/sqrt(d_k)

        b, a, s, _ = q.shape  # shape parameters
        s1 = k.size(2)  # key/value sequence length

        # Initialize the attention bias
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)

        # Handle the causal mask
        if causal:
            assert attn_mask is None, "causal mask and attention mask cannot be used together"
            # Build a lower-triangular causal mask
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        # Handle a user-provided attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask  # allows ALiBi-style positional biases

        # Compute the attention matrix
        attn = (q @ k.transpose(-2, -1)) * scale_factor  # [B,A,S,S1]
        attn += attn_bias

        # softmax and dropout
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)

        # Compute the output
        x = attn @ v  # [B,A,S,D]
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    # Apply the post-processing transform
    x = post_attn_layout(x)  # restore the original dimension order

    # Merge the attention-head dimensions
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)  # [B,S,A*D]
    return out
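A small usage sketch (not from the repository) showing the tensor shapes the deleted attention() helper expects; the import path assumes modules.zip has been unpacked so that modules/attention.py is importable.

# Illustrative call into the deleted attention() helper (shapes are arbitrary).
import torch
from modules.attention import attention  # assumes modules.zip has been unpacked

q = torch.randn(2, 16, 8, 64)  # [batch, seq_len, num_heads, head_dim]
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)

out = attention(q, k, v, mode="torch", causal=True)
print(out.shape)  # torch.Size([2, 16, 512]) -- heads and head_dim merged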
autoencoder.cpython-310.pyc
DELETED
Binary file (8.78 kB)
autoencoder.py
DELETED
@@ -1,326 +0,0 @@
# Modified from Flux
#
# Copyright 2024 Black Forest Labs

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from einops import rearrange
from torch import Tensor, nn


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, *tuple(ch_mult))
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        scale_factor: float,
        shift_factor: float,
    ):
        super().__init__()
        self.encoder = Encoder(
            resolution=resolution,
            in_channels=in_channels,
            ch=ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            z_channels=z_channels,
        )
        self.decoder = Decoder(
            resolution=resolution,
            in_channels=in_channels,
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            z_channels=z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = scale_factor
        self.shift_factor = shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
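A minimal instantiation sketch for the deleted AutoEncoder; the hyperparameters below follow the usual Flux-style configuration and are assumptions, not values read from this Space's checkpoints.

# Sketch only: hyperparameters are assumed Flux-style defaults, not taken from this repo.
import torch
from modules.autoencoder import AutoEncoder  # assumes modules.zip has been unpacked

ae = AutoEncoder(
    resolution=256,
    in_channels=3,
    ch=128,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,
    scale_factor=0.3611,   # assumed
    shift_factor=0.1159,   # assumed
)
img = torch.randn(1, 3, 256, 256)
latent = ae.encode(img)   # [1, 16, 32, 32] with the assumed 8x downsampling
recon = ae.decode(latent)  # back to [1, 3, 256, 256]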
conditioner.cpython-310.pyc
DELETED
Binary file (4.94 kB)
conditioner.py
DELETED
@@ -1,216 +0,0 @@
import torch
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)
from torchvision.transforms import ToPILImage

to_pil = ToPILImage()

Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''


def split_string(s):
    # Replace Chinese quotation marks with English ones
    s = s.replace("“", '"').replace("”", '"')  # use english quotes
    result = []
    # Track whether we are currently inside quotes
    in_quotes = False
    temp = ""

    # Walk over every character and its index
    for idx, char in enumerate(s):
        # If the character is a quote and its index is greater than 155
        if char == '"' and idx > 155:
            # Append the quote to the temporary string
            temp += char
            # If we are not inside quotes
            if not in_quotes:
                # Push the temporary string onto the result list
                result.append(temp)
                # Reset the temporary string
                temp = ""

            # Toggle the quote state
            in_quotes = not in_quotes
            continue
        # If we are inside quotes
        if in_quotes:
            # If the character is whitespace
            if char.isspace():
                pass  # have space token

            # Wrap the character in Chinese quotes and append it to the result list
            result.append("“" + char + "”")
        else:
            # Append the character to the temporary string
            temp += char

    # If the temporary string is not empty
    if temp:
        # Push it onto the result list
        result.append(temp)

    return result


class Qwen25VL_7b_Embedder(torch.nn.Module):
    def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
        super(Qwen25VL_7b_Embedder, self).__init__()
        self.max_length = max_length
        self.dtype = dtype
        self.device = device

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            attn_implementation="eager",
        ).to(torch.cuda.current_device())

        self.model.requires_grad_(False)
        self.processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
        )

        self.prefix = Qwen25VL_7b_PREFIX

    def forward(self, caption, ref_images):
        text_list = caption
        embs = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        hidden_states = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        masks = torch.zeros(
            len(text_list),
            self.max_length,
            dtype=torch.long,
            device=torch.cuda.current_device(),
        )
        input_ids_list = []
        attention_mask_list = []
        emb_list = []

        def split_string(s):
            s = s.replace("“", '"').replace("”", '"').replace("'", '''"''')  # use english quotes
            result = []
            in_quotes = False
            temp = ""

            for idx, char in enumerate(s):
                if char == '"' and idx > 155:
                    temp += char
                    if not in_quotes:
                        result.append(temp)
                        temp = ""

                    in_quotes = not in_quotes
                    continue
                if in_quotes:
                    if char.isspace():
                        pass  # have space token

                    result.append("“" + char + "”")
                else:
                    temp += char

            if temp:
                result.append(temp)

            return result

        for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):

            messages = [{"role": "user", "content": []}]

            messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})

            messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})

            # Then append the text itself
            messages[0]["content"].append({"type": "text", "text": f"{txt}"})

            # Preparation for inference
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
            )

            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt",
            )

            old_inputs_ids = inputs.input_ids
            text_split_list = split_string(text)

            token_list = []
            for text_each in text_split_list:
                txt_inputs = self.processor(
                    text=text_each,
                    images=None,
                    videos=None,
                    padding=True,
                    return_tensors="pt",
                )
                token_each = txt_inputs.input_ids
                if token_each[0][0] == 2073 and token_each[0][-1] == 854:
                    token_each = token_each[:, 1:-1]
                    token_list.append(token_each)
                else:
                    token_list.append(token_each)

            new_txt_ids = torch.cat(token_list, dim=1).to("cuda")

            new_txt_ids = new_txt_ids.to(old_inputs_ids.device)

            idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
            idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
            inputs.input_ids = (
                torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                .unsqueeze(0)
                .to("cuda")
            )
            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
            outputs = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pixel_values=inputs.pixel_values.to("cuda"),
                image_grid_thw=inputs.image_grid_thw.to("cuda"),
                output_hidden_states=True,
            )

            emb = outputs["hidden_states"][-1]

            embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
                : self.max_length
            ]

            masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
                (min(self.max_length, emb.shape[1] - 217)),
                dtype=torch.long,
                device=torch.cuda.current_device(),
            )

        return embs, masks
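A quick standalone illustration (assumed import, not from the repository) of what the deleted split_string() helper does: text outside double quotes stays in whole chunks, while every character inside a quoted span that starts after position 155 is wrapped in Chinese quotation marks so the tokenizer sees it character by character.

# Sketch only: assumes modules.zip has been unpacked and transformers/qwen_vl_utils/
# torchvision are installed, since conditioner.py imports them at module level.
from modules.conditioner import split_string

prompt = "x" * 156 + 'Make the sign read "OPEN" in red letters.'
chunks = split_string(prompt)
print(chunks[-8:])  # the quoted letters O, P, E, N appear as individually wrapped chunks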
connector_edit.cpython-310.pyc
DELETED
Binary file (11.8 kB)
connector_edit.py
DELETED
@@ -1,486 +0,0 @@
from typing import Optional

import torch
import torch.nn
from einops import rearrange
from torch import nn

from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")


def get_activation_layer(act_type):
    """get activation layer

    Args:
        act_type (str): the activation type

    Returns:
        torch.nn.functional: the activation layer
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")


class IndividualTokenRefinerBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )

        if self.need_CA:
            self.cross_attnblock = CrossAttnBlock(
                hidden_size=hidden_size,
                heads_num=heads_num,
                mlp_width_ratio=mlp_width_ratio,
                mlp_drop_rate=mlp_drop_rate,
                act_type=act_type,
                qk_norm=qk_norm,
                qk_norm_type=qk_norm_type,
                qkv_bias=qkv_bias,
                **factory_kwargs,
            )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        if self.need_CA:
            x = self.cross_attnblock(x, c, attn_mask, y)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x


class CrossAttnBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.norm1_2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_q = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )
        self.self_attn_kv = nn.Linear(
            hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        norm_y = self.norm1_2(y)
        q = self.self_attn_q(norm_x)
        q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
        kv = self.self_attn_kv(norm_y)
        k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        return x


class IndividualTokenRefiner(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    need_CA=self.need_CA,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
        y: torch.Tensor = None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
                1, 1, seq_len, 1
            )
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            self_attn_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, self_attn_mask, y)

        return x


class SingleTokenRefiner(torch.nn.Module):
    """
    A single token refiner block for llm text embedding refine.
    """
    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        self.need_CA = need_CA
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(
            in_channels, hidden_size, bias=True, **factory_kwargs
        )
        if self.need_CA:
            self.input_embedder_CA = nn.Linear(
                in_channels, hidden_size, bias=True, **factory_kwargs
            )

        act_layer = get_activation_layer(act_type)
        # Build timestep embedding layer
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        # Build context embedding layer
        self.c_embedder = TextProjection(
            in_channels, hidden_size, act_layer, **factory_kwargs
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            need_CA=need_CA,
            **factory_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
        y: torch.LongTensor = None,
    ):
        timestep_aware_representations = self.t_embedder(t)

        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
            context_aware_representations = (x * mask_float).sum(
                dim=1
            ) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)
        if self.need_CA:
            y = self.input_embedder_CA(y)
            x = self.individual_token_refiner(x, c, mask, y)
        else:
            x = self.individual_token_refiner(x, c, mask)

        return x


class Qwen2Connector(torch.nn.Module):
    def __init__(
        self,
        # biclip_dim=1024,
        in_channels=3584,
        hidden_size=4096,
        heads_num=32,
        depth=2,
        need_CA=False,
        device=None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.S = SingleTokenRefiner(
            in_channels=in_channels,
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            need_CA=need_CA,
            **factory_kwargs,
        )
        self.global_proj_out = nn.Linear(in_channels, 768)

        self.scale_factor = nn.Parameter(torch.zeros(1))
        with torch.no_grad():
            self.scale_factor.data += -(1 - 0.09)

    def forward(self, x, t, mask):
        mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
        x_mean = (x * mask_float).sum(
            dim=1
        ) / mask_float.sum(dim=1) * (1 + self.scale_factor)

        global_out = self.global_proj_out(x_mean)
        encoder_hidden_states = self.S(x, t, mask)
        return encoder_hidden_states, global_out
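A tiny self-contained check (not part of the repository) of the RMSNorm formula used in connector_edit.py: scaling by the reciprocal root-mean-square leaves the last dimension with unit RMS.

# Minimal numerical check of the RMSNorm formula (standalone sketch).
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same expression as RMSNorm._norm: x * rsqrt(mean(x^2, dim=-1) + eps)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 5, 64)
y = rms_norm(x)
print(y.pow(2).mean(-1).sqrt())  # each entry is ~1.0: unit RMS along the last dim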
cookie.png → examples 2.zip
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5de0f67d94e0e46599bc9619a912a05898b8053ddc0f1f6563a3ee3b4dd1f7c7
+size 1878523
examples 2/celeb_meme.jpg
DELETED
Git LFS Details

examples 2/cookie.png
DELETED
Git LFS Details

examples 2/ghibli_meme.jpg
DELETED
Binary file (38.1 kB)

examples 2/leather.jpg
DELETED
Git LFS Details

examples 2/meme.jpg
DELETED
Binary file (49.8 kB)

examples 2/no_cookie.png
DELETED
Git LFS Details

examples 2/poster.jpg
DELETED
Binary file (65.4 kB)

examples 2/poster_orig.jpg
DELETED
Git LFS Details

ghibli_meme.jpg
DELETED
Binary file (38.1 kB)
layers.cpython-310.pyc
DELETED
Binary file (19.1 kB)
layers.py
DELETED
@@ -1,640 +0,0 @@
# Modified from Flux
#
# Copyright 2024 Black Forest Labs

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math  # noqa: I001
from dataclasses import dataclass
from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange
# from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from torch import Tensor, nn


try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None


MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    batch_size=1,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_q (int): The maximum sequence length in the batch of q.
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    q = pre_attn_layout(q)
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_varlen_func is not None
        x: torch.Tensor = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )  # type: ignore
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # type: ignore  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, (
                "Causal mask and attn_mask cannot be used together"
            )
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out


def apply_gate(x, gate=None, tanh=False):
    """AI is creating summary for apply_gate

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to use tanh function. Defaults to False.

    Returns:
        torch.Tensor: the output tensor after apply gate.
    """
    if gate is None:
        return x
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)


class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = (bias, bias)
        drop_probs = (drop, drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(
            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
        )
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_channels, device=device, dtype=dtype)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(
            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
        )
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.linear_1 = nn.Linear(
            in_features=in_channels,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
|
271 |
-
super().__init__()
|
272 |
-
self.frequency_embedding_size = frequency_embedding_size
|
273 |
-
self.max_period = max_period
|
274 |
-
if out_size is None:
|
275 |
-
out_size = hidden_size
|
276 |
-
|
277 |
-
self.mlp = nn.Sequential(
|
278 |
-
nn.Linear(
|
279 |
-
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
|
280 |
-
),
|
281 |
-
act_layer(),
|
282 |
-
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
283 |
-
)
|
284 |
-
nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
|
285 |
-
nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
|
286 |
-
|
287 |
-
@staticmethod
|
288 |
-
def timestep_embedding(t, dim, max_period=10000):
|
289 |
-
"""
|
290 |
-
Create sinusoidal timestep embeddings.
|
291 |
-
|
292 |
-
Args:
|
293 |
-
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
294 |
-
dim (int): the dimension of the output.
|
295 |
-
max_period (int): controls the minimum frequency of the embeddings.
|
296 |
-
|
297 |
-
Returns:
|
298 |
-
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
299 |
-
|
300 |
-
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
301 |
-
"""
|
302 |
-
half = dim // 2
|
303 |
-
freqs = torch.exp(
|
304 |
-
-math.log(max_period)
|
305 |
-
* torch.arange(start=0, end=half, dtype=torch.float32)
|
306 |
-
/ half
|
307 |
-
).to(device=t.device)
|
308 |
-
args = t[:, None].float() * freqs[None]
|
309 |
-
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
310 |
-
if dim % 2:
|
311 |
-
embedding = torch.cat(
|
312 |
-
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
313 |
-
)
|
314 |
-
return embedding
|
315 |
-
|
316 |
-
def forward(self, t):
|
317 |
-
t_freq = self.timestep_embedding(
|
318 |
-
t, self.frequency_embedding_size, self.max_period
|
319 |
-
).type(self.mlp[0].weight.dtype) # type: ignore
|
320 |
-
t_emb = self.mlp(t_freq)
|
321 |
-
return t_emb
|
322 |
-
|
323 |
-
|
324 |
-
class EmbedND(nn.Module):
|
325 |
-
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
326 |
-
super().__init__()
|
327 |
-
self.dim = dim
|
328 |
-
self.theta = theta
|
329 |
-
self.axes_dim = axes_dim
|
330 |
-
|
331 |
-
def forward(self, ids: Tensor) -> Tensor:
|
332 |
-
n_axes = ids.shape[-1]
|
333 |
-
emb = torch.cat(
|
334 |
-
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
335 |
-
dim=-3,
|
336 |
-
)
|
337 |
-
|
338 |
-
return emb.unsqueeze(1)
|
339 |
-
|
340 |
-
|
341 |
-
class MLPEmbedder(nn.Module):
|
342 |
-
def __init__(self, in_dim: int, hidden_dim: int):
|
343 |
-
super().__init__()
|
344 |
-
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
345 |
-
self.silu = nn.SiLU()
|
346 |
-
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
347 |
-
|
348 |
-
def forward(self, x: Tensor) -> Tensor:
|
349 |
-
return self.out_layer(self.silu(self.in_layer(x)))
|
350 |
-
|
351 |
-
|
352 |
-
def rope(pos, dim: int, theta: int):
|
353 |
-
assert dim % 2 == 0
|
354 |
-
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
355 |
-
omega = 1.0 / (theta**scale)
|
356 |
-
out = torch.einsum("...n,d->...nd", pos, omega)
|
357 |
-
out = torch.stack(
|
358 |
-
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
|
359 |
-
)
|
360 |
-
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
361 |
-
return out.float()
|
362 |
-
|
363 |
-
|
364 |
-
def attention_after_rope(q, k, v, pe):
|
365 |
-
q, k = apply_rope(q, k, pe)
|
366 |
-
|
367 |
-
from .attention import attention
|
368 |
-
|
369 |
-
x = attention(q, k, v, mode="torch")
|
370 |
-
return x
|
371 |
-
|
372 |
-
|
373 |
-
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
|
374 |
-
def apply_rope(xq, xk, freqs_cis):
|
375 |
-
# 将 num_heads 和 seq_len 的维度交换回原函数的处理顺序
|
376 |
-
xq = xq.transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
|
377 |
-
xk = xk.transpose(1, 2)
|
378 |
-
|
379 |
-
# 将 head_dim 拆分为复数部分(实部和虚部)
|
380 |
-
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
381 |
-
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
382 |
-
|
383 |
-
# 应用旋转位置编码(复数乘法)
|
384 |
-
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
385 |
-
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
386 |
-
|
387 |
-
# 恢复张量形状并转置回目标维度顺序
|
388 |
-
xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
|
389 |
-
xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
|
390 |
-
|
391 |
-
return xq_out, xk_out
|
392 |
-
|
393 |
-
|
394 |
-
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
|
395 |
-
def scale_add_residual(
|
396 |
-
x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
|
397 |
-
) -> torch.Tensor:
|
398 |
-
return x * scale + residual
|
399 |
-
|
400 |
-
|
401 |
-
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
|
402 |
-
def layernorm_and_scale_shift(
|
403 |
-
x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
|
404 |
-
) -> torch.Tensor:
|
405 |
-
return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift
|
406 |
-
|
407 |
-
|
408 |
-
class SelfAttention(nn.Module):
|
409 |
-
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
410 |
-
super().__init__()
|
411 |
-
self.num_heads = num_heads
|
412 |
-
head_dim = dim // num_heads
|
413 |
-
|
414 |
-
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
415 |
-
self.norm = QKNorm(head_dim)
|
416 |
-
self.proj = nn.Linear(dim, dim)
|
417 |
-
|
418 |
-
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
419 |
-
qkv = self.qkv(x)
|
420 |
-
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
421 |
-
q, k = self.norm(q, k, v)
|
422 |
-
x = attention_after_rope(q, k, v, pe=pe)
|
423 |
-
x = self.proj(x)
|
424 |
-
return x
|
425 |
-
|
426 |
-
|
427 |
-
@dataclass
|
428 |
-
class ModulationOut:
|
429 |
-
shift: Tensor
|
430 |
-
scale: Tensor
|
431 |
-
gate: Tensor
|
432 |
-
|
433 |
-
|
434 |
-
class RMSNorm(torch.nn.Module):
|
435 |
-
def __init__(self, dim: int):
|
436 |
-
super().__init__()
|
437 |
-
self.scale = nn.Parameter(torch.ones(dim))
|
438 |
-
|
439 |
-
# @staticmethod
|
440 |
-
# def rms_norm_fast(x, weight, eps):
|
441 |
-
# return LigerRMSNormFunction.apply(
|
442 |
-
# x,
|
443 |
-
# weight,
|
444 |
-
# eps,
|
445 |
-
# 0.0,
|
446 |
-
# "gemma",
|
447 |
-
# True,
|
448 |
-
# )
|
449 |
-
|
450 |
-
@staticmethod
|
451 |
-
def rms_norm(x, weight, eps):
|
452 |
-
x_dtype = x.dtype
|
453 |
-
x = x.float()
|
454 |
-
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
455 |
-
return (x * rrms).to(dtype=x_dtype) * weight
|
456 |
-
|
457 |
-
def forward(self, x: Tensor):
|
458 |
-
# return self.rms_norm_fast(x, self.scale, 1e-6)
|
459 |
-
return self.rms_norm(x, self.scale, 1e-6)
|
460 |
-
|
461 |
-
|
462 |
-
class QKNorm(torch.nn.Module):
|
463 |
-
def __init__(self, dim: int):
|
464 |
-
super().__init__()
|
465 |
-
self.query_norm = RMSNorm(dim)
|
466 |
-
self.key_norm = RMSNorm(dim)
|
467 |
-
|
468 |
-
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
469 |
-
q = self.query_norm(q)
|
470 |
-
k = self.key_norm(k)
|
471 |
-
return q.to(v), k.to(v)
|
472 |
-
|
473 |
-
|
474 |
-
class Modulation(nn.Module):
|
475 |
-
def __init__(self, dim: int, double: bool):
|
476 |
-
super().__init__()
|
477 |
-
self.is_double = double
|
478 |
-
self.multiplier = 6 if double else 3
|
479 |
-
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
480 |
-
|
481 |
-
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
482 |
-
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
|
483 |
-
self.multiplier, dim=-1
|
484 |
-
)
|
485 |
-
|
486 |
-
return (
|
487 |
-
ModulationOut(*out[:3]),
|
488 |
-
ModulationOut(*out[3:]) if self.is_double else None,
|
489 |
-
)
|
490 |
-
|
491 |
-
|
492 |
-
class DoubleStreamBlock(nn.Module):
|
493 |
-
def __init__(
|
494 |
-
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
|
495 |
-
):
|
496 |
-
super().__init__()
|
497 |
-
|
498 |
-
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
499 |
-
self.num_heads = num_heads
|
500 |
-
self.hidden_size = hidden_size
|
501 |
-
self.img_mod = Modulation(hidden_size, double=True)
|
502 |
-
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
503 |
-
self.img_attn = SelfAttention(
|
504 |
-
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
505 |
-
)
|
506 |
-
|
507 |
-
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
508 |
-
self.img_mlp = nn.Sequential(
|
509 |
-
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
510 |
-
nn.GELU(approximate="tanh"),
|
511 |
-
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
512 |
-
)
|
513 |
-
|
514 |
-
self.txt_mod = Modulation(hidden_size, double=True)
|
515 |
-
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
516 |
-
self.txt_attn = SelfAttention(
|
517 |
-
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
518 |
-
)
|
519 |
-
|
520 |
-
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
521 |
-
self.txt_mlp = nn.Sequential(
|
522 |
-
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
523 |
-
nn.GELU(approximate="tanh"),
|
524 |
-
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
525 |
-
)
|
526 |
-
|
527 |
-
def forward(
|
528 |
-
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
|
529 |
-
) -> tuple[Tensor, Tensor]:
|
530 |
-
img_mod1, img_mod2 = self.img_mod(vec)
|
531 |
-
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
532 |
-
|
533 |
-
# prepare image for attention
|
534 |
-
img_modulated = self.img_norm1(img)
|
535 |
-
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
536 |
-
img_qkv = self.img_attn.qkv(img_modulated)
|
537 |
-
img_q, img_k, img_v = rearrange(
|
538 |
-
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
|
539 |
-
)
|
540 |
-
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
541 |
-
|
542 |
-
# prepare txt for attention
|
543 |
-
txt_modulated = self.txt_norm1(txt)
|
544 |
-
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
545 |
-
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
546 |
-
txt_q, txt_k, txt_v = rearrange(
|
547 |
-
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
|
548 |
-
)
|
549 |
-
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
550 |
-
|
551 |
-
# run actual attention
|
552 |
-
q = torch.cat((txt_q, img_q), dim=1)
|
553 |
-
k = torch.cat((txt_k, img_k), dim=1)
|
554 |
-
v = torch.cat((txt_v, img_v), dim=1)
|
555 |
-
|
556 |
-
attn = attention_after_rope(q, k, v, pe=pe)
|
557 |
-
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
558 |
-
|
559 |
-
# calculate the img bloks
|
560 |
-
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
561 |
-
img_mlp = self.img_mlp(
|
562 |
-
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
563 |
-
)
|
564 |
-
img = scale_add_residual(img_mlp, img_mod2.gate, img)
|
565 |
-
|
566 |
-
# calculate the txt bloks
|
567 |
-
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
568 |
-
txt_mlp = self.txt_mlp(
|
569 |
-
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
570 |
-
)
|
571 |
-
txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
|
572 |
-
return img, txt
|
573 |
-
|
574 |
-
|
575 |
-
class SingleStreamBlock(nn.Module):
|
576 |
-
"""
|
577 |
-
A DiT block with parallel linear layers as described in
|
578 |
-
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
579 |
-
"""
|
580 |
-
|
581 |
-
def __init__(
|
582 |
-
self,
|
583 |
-
hidden_size: int,
|
584 |
-
num_heads: int,
|
585 |
-
mlp_ratio: float = 4.0,
|
586 |
-
qk_scale: float | None = None,
|
587 |
-
):
|
588 |
-
super().__init__()
|
589 |
-
self.hidden_dim = hidden_size
|
590 |
-
self.num_heads = num_heads
|
591 |
-
head_dim = hidden_size // num_heads
|
592 |
-
self.scale = qk_scale or head_dim**-0.5
|
593 |
-
|
594 |
-
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
595 |
-
# qkv and mlp_in
|
596 |
-
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
597 |
-
# proj and mlp_out
|
598 |
-
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
599 |
-
|
600 |
-
self.norm = QKNorm(head_dim)
|
601 |
-
|
602 |
-
self.hidden_size = hidden_size
|
603 |
-
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
604 |
-
|
605 |
-
self.mlp_act = nn.GELU(approximate="tanh")
|
606 |
-
self.modulation = Modulation(hidden_size, double=False)
|
607 |
-
|
608 |
-
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
609 |
-
mod, _ = self.modulation(vec)
|
610 |
-
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
611 |
-
qkv, mlp = torch.split(
|
612 |
-
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
|
613 |
-
)
|
614 |
-
|
615 |
-
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
616 |
-
q, k = self.norm(q, k, v)
|
617 |
-
|
618 |
-
# compute attention
|
619 |
-
attn = attention_after_rope(q, k, v, pe=pe)
|
620 |
-
# compute activation in mlp stream, cat again and run second linear layer
|
621 |
-
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
622 |
-
return scale_add_residual(output, mod.gate, x)
|
623 |
-
|
624 |
-
|
625 |
-
class LastLayer(nn.Module):
|
626 |
-
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
627 |
-
super().__init__()
|
628 |
-
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
629 |
-
self.linear = nn.Linear(
|
630 |
-
hidden_size, patch_size * patch_size * out_channels, bias=True
|
631 |
-
)
|
632 |
-
self.adaLN_modulation = nn.Sequential(
|
633 |
-
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
634 |
-
)
|
635 |
-
|
636 |
-
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
637 |
-
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
638 |
-
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
639 |
-
x = self.linear(x)
|
640 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
leather.jpg
DELETED
Git LFS Details
|
meme.jpg
DELETED
Binary file (49.8 kB)
|
|
model_edit.cpython-310.pyc
DELETED
Binary file (4.21 kB)
|
|
model_edit.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from dataclasses import dataclass
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
from torch import Tensor, nn
|
7 |
-
|
8 |
-
from .connector_edit import Qwen2Connector
|
9 |
-
from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock
|
10 |
-
|
11 |
-
|
12 |
-
@dataclass
|
13 |
-
class Step1XParams:
|
14 |
-
in_channels: int
|
15 |
-
out_channels: int
|
16 |
-
vec_in_dim: int
|
17 |
-
context_in_dim: int
|
18 |
-
hidden_size: int
|
19 |
-
mlp_ratio: float
|
20 |
-
num_heads: int
|
21 |
-
depth: int
|
22 |
-
depth_single_blocks: int
|
23 |
-
axes_dim: list[int]
|
24 |
-
theta: int
|
25 |
-
qkv_bias: bool
|
26 |
-
|
27 |
-
|
28 |
-
class Step1XEdit(nn.Module):
|
29 |
-
"""
|
30 |
-
Transformer model for flow matching on sequences.
|
31 |
-
"""
|
32 |
-
|
33 |
-
def __init__(self, params: Step1XParams):
|
34 |
-
super().__init__()
|
35 |
-
|
36 |
-
self.params = params
|
37 |
-
self.in_channels = params.in_channels
|
38 |
-
self.out_channels = params.out_channels
|
39 |
-
if params.hidden_size % params.num_heads != 0:
|
40 |
-
raise ValueError(
|
41 |
-
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
42 |
-
)
|
43 |
-
pe_dim = params.hidden_size // params.num_heads
|
44 |
-
if sum(params.axes_dim) != pe_dim:
|
45 |
-
raise ValueError(
|
46 |
-
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
47 |
-
)
|
48 |
-
self.hidden_size = params.hidden_size
|
49 |
-
self.num_heads = params.num_heads
|
50 |
-
self.pe_embedder = EmbedND(
|
51 |
-
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
52 |
-
)
|
53 |
-
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
54 |
-
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
55 |
-
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
56 |
-
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
57 |
-
|
58 |
-
self.double_blocks = nn.ModuleList(
|
59 |
-
[
|
60 |
-
DoubleStreamBlock(
|
61 |
-
self.hidden_size,
|
62 |
-
self.num_heads,
|
63 |
-
mlp_ratio=params.mlp_ratio,
|
64 |
-
qkv_bias=params.qkv_bias,
|
65 |
-
)
|
66 |
-
for _ in range(params.depth)
|
67 |
-
]
|
68 |
-
)
|
69 |
-
|
70 |
-
self.single_blocks = nn.ModuleList(
|
71 |
-
[
|
72 |
-
SingleStreamBlock(
|
73 |
-
self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
|
74 |
-
)
|
75 |
-
for _ in range(params.depth_single_blocks)
|
76 |
-
]
|
77 |
-
)
|
78 |
-
|
79 |
-
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
80 |
-
|
81 |
-
self.connector = Qwen2Connector()
|
82 |
-
|
83 |
-
@staticmethod
|
84 |
-
def timestep_embedding(
|
85 |
-
t: Tensor, dim, max_period=10000, time_factor: float = 1000.0
|
86 |
-
):
|
87 |
-
"""
|
88 |
-
Create sinusoidal timestep embeddings.
|
89 |
-
:param t: a 1-D Tensor of N indices, one per batch element.
|
90 |
-
These may be fractional.
|
91 |
-
:param dim: the dimension of the output.
|
92 |
-
:param max_period: controls the minimum frequency of the embeddings.
|
93 |
-
:return: an (N, D) Tensor of positional embeddings.
|
94 |
-
"""
|
95 |
-
t = time_factor * t
|
96 |
-
half = dim // 2
|
97 |
-
freqs = torch.exp(
|
98 |
-
-math.log(max_period)
|
99 |
-
* torch.arange(start=0, end=half, dtype=torch.float32)
|
100 |
-
/ half
|
101 |
-
).to(t.device)
|
102 |
-
|
103 |
-
args = t[:, None].float() * freqs[None]
|
104 |
-
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
105 |
-
if dim % 2:
|
106 |
-
embedding = torch.cat(
|
107 |
-
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
108 |
-
)
|
109 |
-
if torch.is_floating_point(t):
|
110 |
-
embedding = embedding.to(t)
|
111 |
-
return embedding
|
112 |
-
|
113 |
-
def forward(
|
114 |
-
self,
|
115 |
-
img: Tensor,
|
116 |
-
img_ids: Tensor,
|
117 |
-
txt: Tensor,
|
118 |
-
txt_ids: Tensor,
|
119 |
-
timesteps: Tensor,
|
120 |
-
y: Tensor,
|
121 |
-
) -> Tensor:
|
122 |
-
if img.ndim != 3 or txt.ndim != 3:
|
123 |
-
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
124 |
-
|
125 |
-
img = self.img_in(img)
|
126 |
-
vec = self.time_in(self.timestep_embedding(timesteps, 256))
|
127 |
-
|
128 |
-
vec = vec + self.vector_in(y)
|
129 |
-
txt = self.txt_in(txt)
|
130 |
-
|
131 |
-
ids = torch.cat((txt_ids, img_ids), dim=1)
|
132 |
-
pe = self.pe_embedder(ids)
|
133 |
-
|
134 |
-
for block in self.double_blocks:
|
135 |
-
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
136 |
-
|
137 |
-
img = torch.cat((txt, img), 1)
|
138 |
-
for block in self.single_blocks:
|
139 |
-
img = block(img, vec=vec, pe=pe)
|
140 |
-
img = img[:, txt.shape[1] :, ...]
|
141 |
-
|
142 |
-
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
143 |
-
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
celeb_meme.jpg → modules.zip
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c432d89999f0ae531c09c6ccf1d4a69bf5c2bb878f23411fafdf64b7370c8afe
|
3 |
+
size 45293
|
modules/__init__.py
DELETED
File without changes
|
modules/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (128 Bytes)
|
|
modules/__pycache__/attention.cpython-310.pyc
DELETED
Binary file (3.13 kB)
|
|
modules/__pycache__/autoencoder.cpython-310.pyc
DELETED
Binary file (8.78 kB)
|
|
modules/__pycache__/conditioner.cpython-310.pyc
DELETED
Binary file (4.94 kB)
|
|
modules/__pycache__/connector_edit.cpython-310.pyc
DELETED
Binary file (11.8 kB)
|
|
modules/__pycache__/layers.cpython-310.pyc
DELETED
Binary file (19.1 kB)
|
|
modules/__pycache__/model_edit.cpython-310.pyc
DELETED
Binary file (4.21 kB)
|
|
modules/attention.py
DELETED
@@ -1,133 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
|
7 |
-
try:
|
8 |
-
import flash_attn
|
9 |
-
from flash_attn.flash_attn_interface import (
|
10 |
-
_flash_attn_forward,
|
11 |
-
flash_attn_func,
|
12 |
-
flash_attn_varlen_func,
|
13 |
-
)
|
14 |
-
except ImportError:
|
15 |
-
flash_attn = None
|
16 |
-
flash_attn_varlen_func = None
|
17 |
-
_flash_attn_forward = None
|
18 |
-
flash_attn_func = None
|
19 |
-
|
20 |
-
MEMORY_LAYOUT = {
|
21 |
-
# flash模式:
|
22 |
-
# 预处理: 输入 [batch_size, seq_len, num_heads, head_dim]
|
23 |
-
# 后处理: 保持形状不变
|
24 |
-
"flash": (
|
25 |
-
lambda x: x, # 保持形状
|
26 |
-
lambda x: x, # 保持形状
|
27 |
-
),
|
28 |
-
# torch/vanilla模式:
|
29 |
-
# 预处理: 交换序列和注意力头的维度 [B,S,A,D] -> [B,A,S,D]
|
30 |
-
# 后处理: 交换回原始维度 [B,A,S,D] -> [B,S,A,D]
|
31 |
-
"torch": (
|
32 |
-
lambda x: x.transpose(1, 2), # (B,S,A,D) -> (B,A,S,D)
|
33 |
-
lambda x: x.transpose(1, 2), # (B,A,S,D) -> (B,S,A,D)
|
34 |
-
),
|
35 |
-
"vanilla": (
|
36 |
-
lambda x: x.transpose(1, 2),
|
37 |
-
lambda x: x.transpose(1, 2),
|
38 |
-
),
|
39 |
-
}
|
40 |
-
|
41 |
-
|
42 |
-
def attention(
|
43 |
-
q,
|
44 |
-
k,
|
45 |
-
v,
|
46 |
-
mode="torch",
|
47 |
-
drop_rate=0,
|
48 |
-
attn_mask=None,
|
49 |
-
causal=False,
|
50 |
-
):
|
51 |
-
"""
|
52 |
-
执行QKV自注意力计算
|
53 |
-
|
54 |
-
Args:
|
55 |
-
q (torch.Tensor): 查询张量,形状 [batch_size, seq_len, num_heads, head_dim]
|
56 |
-
k (torch.Tensor): 键张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
|
57 |
-
v (torch.Tensor): 值张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
|
58 |
-
mode (str): 注意力模式,可选 'flash', 'torch', 'vanilla'
|
59 |
-
drop_rate (float): 注意力矩阵的dropout概率
|
60 |
-
attn_mask (torch.Tensor): 注意力掩码,形状根据模式不同而变化
|
61 |
-
causal (bool): 是否使用因果注意力(仅关注前面位置)
|
62 |
-
|
63 |
-
Returns:
|
64 |
-
torch.Tensor: 注意力输出,形状 [batch_size, seq_len, num_heads * head_dim]
|
65 |
-
"""
|
66 |
-
# 获取预处理和后处理函数
|
67 |
-
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
68 |
-
|
69 |
-
# 应用预处理变换
|
70 |
-
q = pre_attn_layout(q) # 形状根据模式变化
|
71 |
-
k = pre_attn_layout(k)
|
72 |
-
v = pre_attn_layout(v)
|
73 |
-
|
74 |
-
if mode == "torch":
|
75 |
-
# 使用PyTorch原生的scaled_dot_product_attention
|
76 |
-
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
77 |
-
attn_mask = attn_mask.to(q.dtype)
|
78 |
-
x = F.scaled_dot_product_attention(
|
79 |
-
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
80 |
-
)
|
81 |
-
elif mode == "flash":
|
82 |
-
assert flash_attn_func is not None, "flash_attn_func未定义"
|
83 |
-
assert attn_mask is None, "不支持的注意力掩码"
|
84 |
-
x: torch.Tensor = flash_attn_func(
|
85 |
-
q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
|
86 |
-
) # type: ignore
|
87 |
-
elif mode == "vanilla":
|
88 |
-
# 手动实现注意力机制
|
89 |
-
scale_factor = 1 / math.sqrt(q.size(-1)) # 缩放因子 1/sqrt(d_k)
|
90 |
-
|
91 |
-
b, a, s, _ = q.shape # 获取形状参数
|
92 |
-
s1 = k.size(2) # 键值序列长度
|
93 |
-
|
94 |
-
# 初始化注意力偏置
|
95 |
-
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
96 |
-
|
97 |
-
# 处理因果掩码
|
98 |
-
if causal:
|
99 |
-
assert attn_mask is None, "因果掩码和注意力掩码不能同时使用"
|
100 |
-
# 生成下三角因果掩码
|
101 |
-
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
|
102 |
-
diagonal=0
|
103 |
-
)
|
104 |
-
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
105 |
-
attn_bias = attn_bias.to(q.dtype)
|
106 |
-
|
107 |
-
# 处理自定义注意力掩码
|
108 |
-
if attn_mask is not None:
|
109 |
-
if attn_mask.dtype == torch.bool:
|
110 |
-
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
111 |
-
else:
|
112 |
-
attn_bias += attn_mask # 允许类似ALiBi的位置偏置
|
113 |
-
|
114 |
-
# 计算注意力矩阵
|
115 |
-
attn = (q @ k.transpose(-2, -1)) * scale_factor # [B,A,S,S1]
|
116 |
-
attn += attn_bias
|
117 |
-
|
118 |
-
# softmax和dropout
|
119 |
-
attn = attn.softmax(dim=-1)
|
120 |
-
attn = torch.dropout(attn, p=drop_rate, train=True)
|
121 |
-
|
122 |
-
# 计算输出
|
123 |
-
x = attn @ v # [B,A,S,D]
|
124 |
-
else:
|
125 |
-
raise NotImplementedError(f"不支持的注意力模式: {mode}")
|
126 |
-
|
127 |
-
# 应用后处理变换
|
128 |
-
x = post_attn_layout(x) # 恢复原始维度顺序
|
129 |
-
|
130 |
-
# 合并注意力头维度
|
131 |
-
b, s, a, d = x.shape
|
132 |
-
out = x.reshape(b, s, -1) # [B,S,A*D]
|
133 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/autoencoder.py
DELETED
@@ -1,326 +0,0 @@
|
|
1 |
-
# Modified from Flux
|
2 |
-
#
|
3 |
-
# Copyright 2024 Black Forest Labs
|
4 |
-
|
5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
-
# you may not use this file except in compliance with the License.
|
7 |
-
# You may obtain a copy of the License at
|
8 |
-
|
9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
-
|
11 |
-
# Unless required by applicable law or agreed to in writing, software
|
12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
-
# See the License for the specific language governing permissions and
|
15 |
-
# limitations under the License.
|
16 |
-
#
|
17 |
-
# This source code is licensed under the license found in the
|
18 |
-
# LICENSE file in the root directory of this source tree.
|
19 |
-
import torch
|
20 |
-
from einops import rearrange
|
21 |
-
from torch import Tensor, nn
|
22 |
-
|
23 |
-
|
24 |
-
def swish(x: Tensor) -> Tensor:
|
25 |
-
return x * torch.sigmoid(x)
|
26 |
-
|
27 |
-
|
28 |
-
class AttnBlock(nn.Module):
|
29 |
-
def __init__(self, in_channels: int):
|
30 |
-
super().__init__()
|
31 |
-
self.in_channels = in_channels
|
32 |
-
|
33 |
-
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
34 |
-
|
35 |
-
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
36 |
-
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
37 |
-
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
38 |
-
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
39 |
-
|
40 |
-
def attention(self, h_: Tensor) -> Tensor:
|
41 |
-
h_ = self.norm(h_)
|
42 |
-
q = self.q(h_)
|
43 |
-
k = self.k(h_)
|
44 |
-
v = self.v(h_)
|
45 |
-
|
46 |
-
b, c, h, w = q.shape
|
47 |
-
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
48 |
-
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
49 |
-
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
50 |
-
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
51 |
-
|
52 |
-
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
53 |
-
|
54 |
-
def forward(self, x: Tensor) -> Tensor:
|
55 |
-
return x + self.proj_out(self.attention(x))
|
56 |
-
|
57 |
-
|
58 |
-
class ResnetBlock(nn.Module):
|
59 |
-
def __init__(self, in_channels: int, out_channels: int):
|
60 |
-
super().__init__()
|
61 |
-
self.in_channels = in_channels
|
62 |
-
out_channels = in_channels if out_channels is None else out_channels
|
63 |
-
self.out_channels = out_channels
|
64 |
-
|
65 |
-
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
66 |
-
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
67 |
-
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
68 |
-
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
69 |
-
if self.in_channels != self.out_channels:
|
70 |
-
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
71 |
-
|
72 |
-
def forward(self, x):
|
73 |
-
h = x
|
74 |
-
h = self.norm1(h)
|
75 |
-
h = swish(h)
|
76 |
-
h = self.conv1(h)
|
77 |
-
|
78 |
-
h = self.norm2(h)
|
79 |
-
h = swish(h)
|
80 |
-
h = self.conv2(h)
|
81 |
-
|
82 |
-
if self.in_channels != self.out_channels:
|
83 |
-
x = self.nin_shortcut(x)
|
84 |
-
|
85 |
-
return x + h
|
86 |
-
|
87 |
-
|
88 |
-
class Downsample(nn.Module):
|
89 |
-
def __init__(self, in_channels: int):
|
90 |
-
super().__init__()
|
91 |
-
# no asymmetric padding in torch conv, must do it ourselves
|
92 |
-
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
93 |
-
|
94 |
-
def forward(self, x: Tensor):
|
95 |
-
pad = (0, 1, 0, 1)
|
96 |
-
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
97 |
-
x = self.conv(x)
|
98 |
-
return x
|
99 |
-
|
100 |
-
|
101 |
-
class Upsample(nn.Module):
|
102 |
-
def __init__(self, in_channels: int):
|
103 |
-
super().__init__()
|
104 |
-
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
105 |
-
|
106 |
-
def forward(self, x: Tensor):
|
107 |
-
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
108 |
-
x = self.conv(x)
|
109 |
-
return x
|
110 |
-
|
111 |
-
|
112 |
-
class Encoder(nn.Module):
|
113 |
-
def __init__(
|
114 |
-
self,
|
115 |
-
resolution: int,
|
116 |
-
in_channels: int,
|
117 |
-
ch: int,
|
118 |
-
ch_mult: list[int],
|
119 |
-
num_res_blocks: int,
|
120 |
-
z_channels: int,
|
121 |
-
):
|
122 |
-
super().__init__()
|
123 |
-
self.ch = ch
|
124 |
-
self.num_resolutions = len(ch_mult)
|
125 |
-
self.num_res_blocks = num_res_blocks
|
126 |
-
self.resolution = resolution
|
127 |
-
self.in_channels = in_channels
|
128 |
-
# downsampling
|
129 |
-
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
130 |
-
|
131 |
-
curr_res = resolution
|
132 |
-
in_ch_mult = (1, *tuple(ch_mult))
|
133 |
-
self.in_ch_mult = in_ch_mult
|
134 |
-
self.down = nn.ModuleList()
|
135 |
-
block_in = self.ch
|
136 |
-
for i_level in range(self.num_resolutions):
|
137 |
-
block = nn.ModuleList()
|
138 |
-
attn = nn.ModuleList()
|
139 |
-
block_in = ch * in_ch_mult[i_level]
|
140 |
-
block_out = ch * ch_mult[i_level]
|
141 |
-
for _ in range(self.num_res_blocks):
|
142 |
-
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
143 |
-
block_in = block_out
|
144 |
-
down = nn.Module()
|
145 |
-
down.block = block
|
146 |
-
down.attn = attn
|
147 |
-
if i_level != self.num_resolutions - 1:
|
148 |
-
down.downsample = Downsample(block_in)
|
149 |
-
curr_res = curr_res // 2
|
150 |
-
self.down.append(down)
|
151 |
-
|
152 |
-
# middle
|
153 |
-
self.mid = nn.Module()
|
154 |
-
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
155 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
156 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
157 |
-
|
158 |
-
# end
|
159 |
-
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
160 |
-
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
161 |
-
|
162 |
-
def forward(self, x: Tensor) -> Tensor:
|
163 |
-
# downsampling
|
164 |
-
hs = [self.conv_in(x)]
|
165 |
-
for i_level in range(self.num_resolutions):
|
166 |
-
for i_block in range(self.num_res_blocks):
|
167 |
-
h = self.down[i_level].block[i_block](hs[-1])
|
168 |
-
if len(self.down[i_level].attn) > 0:
|
169 |
-
h = self.down[i_level].attn[i_block](h)
|
170 |
-
hs.append(h)
|
171 |
-
if i_level != self.num_resolutions - 1:
|
172 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
173 |
-
|
174 |
-
# middle
|
175 |
-
h = hs[-1]
|
176 |
-
h = self.mid.block_1(h)
|
177 |
-
h = self.mid.attn_1(h)
|
178 |
-
h = self.mid.block_2(h)
|
179 |
-
# end
|
180 |
-
h = self.norm_out(h)
|
181 |
-
h = swish(h)
|
182 |
-
h = self.conv_out(h)
|
183 |
-
return h
|
184 |
-
|
185 |
-
|
186 |
-
class Decoder(nn.Module):
|
187 |
-
def __init__(
|
188 |
-
self,
|
189 |
-
ch: int,
|
190 |
-
out_ch: int,
|
191 |
-
ch_mult: list[int],
|
192 |
-
num_res_blocks: int,
|
193 |
-
in_channels: int,
|
194 |
-
resolution: int,
|
195 |
-
z_channels: int,
|
196 |
-
):
|
197 |
-
super().__init__()
|
198 |
-
self.ch = ch
|
199 |
-
self.num_resolutions = len(ch_mult)
|
200 |
-
self.num_res_blocks = num_res_blocks
|
201 |
-
self.resolution = resolution
|
202 |
-
self.in_channels = in_channels
|
203 |
-
self.ffactor = 2 ** (self.num_resolutions - 1)
|
204 |
-
|
205 |
-
# compute in_ch_mult, block_in and curr_res at lowest res
|
206 |
-
block_in = ch * ch_mult[self.num_resolutions - 1]
|
207 |
-
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
208 |
-
self.z_shape = (1, z_channels, curr_res, curr_res)
|
209 |
-
|
210 |
-
# z to block_in
|
211 |
-
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
212 |
-
|
213 |
-
# middle
|
214 |
-
self.mid = nn.Module()
|
215 |
-
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
216 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
217 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
218 |
-
|
219 |
-
# upsampling
|
220 |
-
self.up = nn.ModuleList()
|
221 |
-
for i_level in reversed(range(self.num_resolutions)):
|
222 |
-
block = nn.ModuleList()
|
223 |
-
attn = nn.ModuleList()
|
224 |
-
block_out = ch * ch_mult[i_level]
|
225 |
-
for _ in range(self.num_res_blocks + 1):
|
226 |
-
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
227 |
-
block_in = block_out
|
228 |
-
up = nn.Module()
|
229 |
-
up.block = block
|
230 |
-
up.attn = attn
|
231 |
-
if i_level != 0:
|
232 |
-
up.upsample = Upsample(block_in)
|
233 |
-
curr_res = curr_res * 2
|
234 |
-
self.up.insert(0, up) # prepend to get consistent order
|
235 |
-
|
236 |
-
# end
|
237 |
-
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
238 |
-
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
239 |
-
|
240 |
-
def forward(self, z: Tensor) -> Tensor:
|
241 |
-
# z to block_in
|
242 |
-
h = self.conv_in(z)
|
243 |
-
|
244 |
-
# middle
|
245 |
-
h = self.mid.block_1(h)
|
246 |
-
h = self.mid.attn_1(h)
|
247 |
-
h = self.mid.block_2(h)
|
248 |
-
|
249 |
-
# upsampling
|
250 |
-
for i_level in reversed(range(self.num_resolutions)):
|
251 |
-
for i_block in range(self.num_res_blocks + 1):
|
252 |
-
h = self.up[i_level].block[i_block](h)
|
253 |
-
if len(self.up[i_level].attn) > 0:
|
254 |
-
h = self.up[i_level].attn[i_block](h)
|
255 |
-
if i_level != 0:
|
256 |
-
h = self.up[i_level].upsample(h)
|
257 |
-
|
258 |
-
# end
|
259 |
-
h = self.norm_out(h)
|
260 |
-
h = swish(h)
|
261 |
-
h = self.conv_out(h)
|
262 |
-
return h
|
263 |
-
|
264 |
-
|
265 |
-
class DiagonalGaussian(nn.Module):
|
266 |
-
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
267 |
-
super().__init__()
|
268 |
-
self.sample = sample
|
269 |
-
self.chunk_dim = chunk_dim
|
270 |
-
|
271 |
-
def forward(self, z: Tensor) -> Tensor:
|
272 |
-
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
273 |
-
if self.sample:
|
274 |
-
std = torch.exp(0.5 * logvar)
|
275 |
-
return mean + std * torch.randn_like(mean)
|
276 |
-
else:
|
277 |
-
return mean
|
278 |
-
|
279 |
-
|
280 |
-
class AutoEncoder(nn.Module):
|
281 |
-
def __init__(
|
282 |
-
self,
|
283 |
-
resolution: int,
|
284 |
-
in_channels: int,
|
285 |
-
ch: int,
|
286 |
-
out_ch: int,
|
287 |
-
ch_mult: list[int],
|
288 |
-
num_res_blocks: int,
|
289 |
-
z_channels: int,
|
290 |
-
scale_factor: float,
|
291 |
-
shift_factor: float,
|
292 |
-
):
|
293 |
-
super().__init__()
|
294 |
-
self.encoder = Encoder(
|
295 |
-
resolution=resolution,
|
296 |
-
in_channels=in_channels,
|
297 |
-
ch=ch,
|
298 |
-
ch_mult=ch_mult,
|
299 |
-
num_res_blocks=num_res_blocks,
|
300 |
-
z_channels=z_channels,
|
301 |
-
)
|
302 |
-
self.decoder = Decoder(
|
303 |
-
resolution=resolution,
|
304 |
-
in_channels=in_channels,
|
305 |
-
ch=ch,
|
306 |
-
out_ch=out_ch,
|
307 |
-
ch_mult=ch_mult,
|
308 |
-
num_res_blocks=num_res_blocks,
|
309 |
-
z_channels=z_channels,
|
310 |
-
)
|
311 |
-
self.reg = DiagonalGaussian()
|
312 |
-
|
313 |
-
self.scale_factor = scale_factor
|
314 |
-
self.shift_factor = shift_factor
|
315 |
-
|
316 |
-
def encode(self, x: Tensor) -> Tensor:
|
317 |
-
z = self.reg(self.encoder(x))
|
318 |
-
z = self.scale_factor * (z - self.shift_factor)
|
319 |
-
return z
|
320 |
-
|
321 |
-
def decode(self, z: Tensor) -> Tensor:
|
322 |
-
z = z / self.scale_factor + self.shift_factor
|
323 |
-
return self.decoder(z)
|
324 |
-
|
325 |
-
def forward(self, x: Tensor) -> Tensor:
|
326 |
-
return self.decode(self.encode(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/conditioner.py
DELETED
@@ -1,216 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from qwen_vl_utils import process_vision_info
|
3 |
-
from transformers import (
|
4 |
-
AutoProcessor,
|
5 |
-
Qwen2VLForConditionalGeneration,
|
6 |
-
Qwen2_5_VLForConditionalGeneration,
|
7 |
-
)
|
8 |
-
from torchvision.transforms import ToPILImage
|
9 |
-
|
10 |
-
to_pil = ToPILImage()
|
11 |
-
|
12 |
-
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
|
13 |
-
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
|
14 |
-
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
|
15 |
-
Here are examples of how to transform or refine prompts:
|
16 |
-
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
|
17 |
-
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
|
18 |
-
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
|
19 |
-
User Prompt:'''
|
20 |
-
|
21 |
-
|
22 |
-
def split_string(s):
|
23 |
-
# 将中文引号替换为英文引号
|
24 |
-
s = s.replace("“", '"').replace("”", '"') # use english quotes
|
25 |
-
result = []
|
26 |
-
# 标记是否在引号内
|
27 |
-
in_quotes = False
|
28 |
-
temp = ""
|
29 |
-
|
30 |
-
# 遍历字符串中的每个字符及其索引
|
31 |
-
for idx, char in enumerate(s):
|
32 |
-
# 如果字符是引号且索引大于 155
|
33 |
-
if char == '"' and idx > 155:
|
34 |
-
# 将引号添加到临时字符串
|
35 |
-
temp += char
|
36 |
-
# 如果不在引号内
|
37 |
-
if not in_quotes:
|
38 |
-
# 将临时字符串添加到结果列表
|
39 |
-
result.append(temp)
|
40 |
-
# 清空临时字符串
|
41 |
-
temp = ""
|
42 |
-
|
43 |
-
# 切换引号状态
|
44 |
-
in_quotes = not in_quotes
|
45 |
-
continue
|
46 |
-
# 如果在引号内
|
47 |
-
if in_quotes:
|
48 |
-
# 如果字符是空格
|
49 |
-
if char.isspace():
|
50 |
-
pass # have space token
|
51 |
-
|
52 |
-
# 将字符用中文引号包裹后添加到结果列表
|
53 |
-
result.append("“" + char + "”")
|
54 |
-
else:
|
55 |
-
# 将字符添加到临时字符串
|
56 |
-
temp += char
|
57 |
-
|
58 |
-
# 如果临时字符串不为空
|
59 |
-
if temp:
|
60 |
-
# 将临时字符串添加到结果列表
|
61 |
-
result.append(temp)
|
62 |
-
|
63 |
-
return result
|
64 |
-
|
65 |
-
|
66 |
-
class Qwen25VL_7b_Embedder(torch.nn.Module):
|
67 |
-
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
68 |
-
super(Qwen25VL_7b_Embedder, self).__init__()
|
69 |
-
self.max_length = max_length
|
70 |
-
self.dtype = dtype
|
71 |
-
self.device = device
|
72 |
-
|
73 |
-
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
74 |
-
model_path,
|
75 |
-
torch_dtype=dtype,
|
76 |
-
attn_implementation="eager",
|
77 |
-
).to(torch.cuda.current_device())
|
78 |
-
|
79 |
-
self.model.requires_grad_(False)
|
80 |
-
self.processor = AutoProcessor.from_pretrained(
|
81 |
-
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
|
82 |
-
)
|
83 |
-
|
84 |
-
self.prefix = Qwen25VL_7b_PREFIX
|
85 |
-
|
86 |
-
def forward(self, caption, ref_images):
|
87 |
-
text_list = caption
|
88 |
-
embs = torch.zeros(
|
89 |
-
len(text_list),
|
90 |
-
self.max_length,
|
91 |
-
self.model.config.hidden_size,
|
92 |
-
dtype=torch.bfloat16,
|
93 |
-
device=torch.cuda.current_device(),
|
94 |
-
)
|
95 |
-
hidden_states = torch.zeros(
|
96 |
-
len(text_list),
|
97 |
-
self.max_length,
|
98 |
-
self.model.config.hidden_size,
|
99 |
-
dtype=torch.bfloat16,
|
100 |
-
device=torch.cuda.current_device(),
|
101 |
-
)
|
102 |
-
masks = torch.zeros(
|
103 |
-
len(text_list),
|
104 |
-
self.max_length,
|
105 |
-
dtype=torch.long,
|
106 |
-
device=torch.cuda.current_device(),
|
107 |
-
)
|
108 |
-
input_ids_list = []
|
109 |
-
attention_mask_list = []
|
110 |
-
emb_list = []
|
111 |
-
|
112 |
-
def split_string(s):
|
113 |
-
s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
|
114 |
-
result = []
|
115 |
-
in_quotes = False
|
116 |
-
temp = ""
|
117 |
-
|
118 |
-
for idx,char in enumerate(s):
|
119 |
-
if char == '"' and idx>155:
|
120 |
-
temp += char
|
121 |
-
if not in_quotes:
|
122 |
-
result.append(temp)
|
123 |
-
temp = ""
|
124 |
-
|
125 |
-
in_quotes = not in_quotes
|
126 |
-
continue
|
127 |
-
if in_quotes:
|
128 |
-
if char.isspace():
|
129 |
-
pass # have space token
|
130 |
-
|
131 |
-
result.append("“" + char + "”")
|
132 |
-
else:
|
133 |
-
temp += char
|
134 |
-
|
135 |
-
if temp:
|
136 |
-
result.append(temp)
|
137 |
-
|
138 |
-
return result
|
139 |
-
|
140 |
-
for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
|
141 |
-
|
142 |
-
messages = [{"role": "user", "content": []}]
|
143 |
-
|
144 |
-
messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
|
145 |
-
|
146 |
-
messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})
|
147 |
-
|
148 |
-
# 再添加 text
|
149 |
-
messages[0]["content"].append({"type": "text", "text": f"{txt}"})
|
150 |
-
|
151 |
-
# Preparation for inference
|
152 |
-
text = self.processor.apply_chat_template(
|
153 |
-
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
|
154 |
-
)
|
155 |
-
|
156 |
-
image_inputs, video_inputs = process_vision_info(messages)
|
157 |
-
|
158 |
-
inputs = self.processor(
|
159 |
-
text=[text],
|
160 |
-
images=image_inputs,
|
161 |
-
padding=True,
|
162 |
-
return_tensors="pt",
|
163 |
-
)
|
164 |
-
|
165 |
-
old_inputs_ids = inputs.input_ids
|
166 |
-
text_split_list = split_string(text)
|
167 |
-
|
168 |
-
token_list = []
|
169 |
-
for text_each in text_split_list:
|
170 |
-
txt_inputs = self.processor(
|
171 |
-
text=text_each,
|
172 |
-
images=None,
|
173 |
-
videos=None,
|
174 |
-
padding=True,
|
175 |
-
return_tensors="pt",
|
176 |
-
)
|
177 |
-
token_each = txt_inputs.input_ids
|
178 |
-
if token_each[0][0] == 2073 and token_each[0][-1] == 854:
|
179 |
-
token_each = token_each[:, 1:-1]
|
180 |
-
token_list.append(token_each)
|
181 |
-
else:
|
182 |
-
token_list.append(token_each)
|
183 |
-
|
184 |
-
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
|
185 |
-
|
186 |
-
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
187 |
-
|
188 |
-
idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
|
189 |
-
idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
|
190 |
-
inputs.input_ids = (
|
191 |
-
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
192 |
-
.unsqueeze(0)
|
193 |
-
.to("cuda")
|
194 |
-
)
|
195 |
-
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
196 |
-
outputs = self.model(
|
197 |
-
input_ids=inputs.input_ids,
|
198 |
-
attention_mask=inputs.attention_mask,
|
199 |
-
pixel_values=inputs.pixel_values.to("cuda"),
|
200 |
-
image_grid_thw=inputs.image_grid_thw.to("cuda"),
|
201 |
-
output_hidden_states=True,
|
202 |
-
)
|
203 |
-
|
204 |
-
emb = outputs["hidden_states"][-1]
|
205 |
-
|
206 |
-
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
|
207 |
-
: self.max_length
|
208 |
-
]
|
209 |
-
|
210 |
-
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
211 |
-
(min(self.max_length, emb.shape[1] - 217)),
|
212 |
-
dtype=torch.long,
|
213 |
-
device=torch.cuda.current_device(),
|
214 |
-
)
|
215 |
-
|
216 |
-
return embs, masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/connector_edit.py
DELETED
@@ -1,486 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn
-from einops import rearrange
-from torch import nn
-
-from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention
-
-
-class RMSNorm(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        elementwise_affine=True,
-        eps: float = 1e-6,
-        device=None,
-        dtype=None,
-    ):
-        """
-        Initialize the RMSNorm normalization layer.
-
-        Args:
-            dim (int): The dimension of the input tensor.
-            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-        Attributes:
-            eps (float): A small value added to the denominator for numerical stability.
-            weight (nn.Parameter): Learnable scaling parameter.
-
-        """
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.eps = eps
-        if elementwise_affine:
-            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
-
-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The normalized tensor.
-
-        """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-
-        """
-        output = self._norm(x.float()).type_as(x)
-        if hasattr(self, "weight"):
-            output = output * self.weight
-        return output
-
-
-def get_norm_layer(norm_layer):
-    """
-    Get the normalization layer.
-
-    Args:
-        norm_layer (str): The type of normalization layer.
-
-    Returns:
-        norm_layer (nn.Module): The normalization layer.
-    """
-    if norm_layer == "layer":
-        return nn.LayerNorm
-    elif norm_layer == "rms":
-        return RMSNorm
-    else:
-        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
-
-
-def get_activation_layer(act_type):
-    """get activation layer
-
-    Args:
-        act_type (str): the activation type
-
-    Returns:
-        torch.nn.functional: the activation layer
-    """
-    if act_type == "gelu":
-        return lambda: nn.GELU()
-    elif act_type == "gelu_tanh":
-        return lambda: nn.GELU(approximate="tanh")
-    elif act_type == "relu":
-        return nn.ReLU
-    elif act_type == "silu":
-        return nn.SiLU
-    else:
-        raise ValueError(f"Unknown activation type: {act_type}")
-
-class IndividualTokenRefinerBlock(torch.nn.Module):
-    def __init__(
-        self,
-        hidden_size,
-        heads_num,
-        mlp_width_ratio: str = 4.0,
-        mlp_drop_rate: float = 0.0,
-        act_type: str = "silu",
-        qk_norm: bool = False,
-        qk_norm_type: str = "layer",
-        qkv_bias: bool = True,
-        need_CA: bool = False,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.need_CA = need_CA
-        self.heads_num = heads_num
-        head_dim = hidden_size // heads_num
-        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
-
-        self.norm1 = nn.LayerNorm(
-            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
-        )
-        self.self_attn_qkv = nn.Linear(
-            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
-        )
-        qk_norm_layer = get_norm_layer(qk_norm_type)
-        self.self_attn_q_norm = (
-            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
-            if qk_norm
-            else nn.Identity()
-        )
-        self.self_attn_k_norm = (
-            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
-            if qk_norm
-            else nn.Identity()
-        )
-        self.self_attn_proj = nn.Linear(
-            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
-        )
-
-        self.norm2 = nn.LayerNorm(
-            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
-        )
-        act_layer = get_activation_layer(act_type)
-        self.mlp = MLP(
-            in_channels=hidden_size,
-            hidden_channels=mlp_hidden_dim,
-            act_layer=act_layer,
-            drop=mlp_drop_rate,
-            **factory_kwargs,
-        )
-
-        self.adaLN_modulation = nn.Sequential(
-            act_layer(),
-            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
-        )
-
-        if self.need_CA:
-            self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
-                                                heads_num=heads_num,
-                                                mlp_width_ratio=mlp_width_ratio,
-                                                mlp_drop_rate=mlp_drop_rate,
-                                                act_type=act_type,
-                                                qk_norm=qk_norm,
-                                                qk_norm_type=qk_norm_type,
-                                                qkv_bias=qkv_bias,
-                                                **factory_kwargs,)
-        # Zero-initialize the modulation
-        nn.init.zeros_(self.adaLN_modulation[1].weight)
-        nn.init.zeros_(self.adaLN_modulation[1].bias)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
-        attn_mask: torch.Tensor = None,
-        y: torch.Tensor = None,
-    ):
-        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
-
-        norm_x = self.norm1(x)
-        qkv = self.self_attn_qkv(norm_x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        # Apply QK-Norm if needed
-        q = self.self_attn_q_norm(q).to(v)
-        k = self.self_attn_k_norm(k).to(v)
-
-        # Self-Attention
-        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
-
-        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
-
-        if self.need_CA:
-            x = self.cross_attnblock(x, c, attn_mask, y)
-
-        # FFN Layer
-        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
-
-        return x
-
-
-
-
-class CrossAttnBlock(torch.nn.Module):
-    def __init__(
-        self,
-        hidden_size,
-        heads_num,
-        mlp_width_ratio: str = 4.0,
-        mlp_drop_rate: float = 0.0,
-        act_type: str = "silu",
-        qk_norm: bool = False,
-        qk_norm_type: str = "layer",
-        qkv_bias: bool = True,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.heads_num = heads_num
-        head_dim = hidden_size // heads_num
-
-        self.norm1 = nn.LayerNorm(
-            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
-        )
-        self.norm1_2 = nn.LayerNorm(
-            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
-        )
-        self.self_attn_q = nn.Linear(
-            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
-        )
-        self.self_attn_kv = nn.Linear(
-            hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
-        )
-        qk_norm_layer = get_norm_layer(qk_norm_type)
-        self.self_attn_q_norm = (
-            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
-            if qk_norm
-            else nn.Identity()
-        )
-        self.self_attn_k_norm = (
-            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
-            if qk_norm
-            else nn.Identity()
-        )
-        self.self_attn_proj = nn.Linear(
-            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
-        )
-
-        self.norm2 = nn.LayerNorm(
-            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
-        )
-        act_layer = get_activation_layer(act_type)
-
-        self.adaLN_modulation = nn.Sequential(
-            act_layer(),
-            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
-        )
-        # Zero-initialize the modulation
-        nn.init.zeros_(self.adaLN_modulation[1].weight)
-        nn.init.zeros_(self.adaLN_modulation[1].bias)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
-        attn_mask: torch.Tensor = None,
-        y: torch.Tensor=None,
-
-    ):
-        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
-
-        norm_x = self.norm1(x)
-        norm_y = self.norm1_2(y)
-        q = self.self_attn_q(norm_x)
-        q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
-        kv = self.self_attn_kv(norm_y)
-        k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
-        # Apply QK-Norm if needed
-        q = self.self_attn_q_norm(q).to(v)
-        k = self.self_attn_k_norm(k).to(v)
-
-        # Cross-Attention
-        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
-
-        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
-
-        return x
-
-
-
-class IndividualTokenRefiner(torch.nn.Module):
-    def __init__(
-        self,
-        hidden_size,
-        heads_num,
-        depth,
-        mlp_width_ratio: float = 4.0,
-        mlp_drop_rate: float = 0.0,
-        act_type: str = "silu",
-        qk_norm: bool = False,
-        qk_norm_type: str = "layer",
-        qkv_bias: bool = True,
-        need_CA:bool=False,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-    ):
-
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.need_CA = need_CA
-        self.blocks = nn.ModuleList(
-            [
-                IndividualTokenRefinerBlock(
-                    hidden_size=hidden_size,
-                    heads_num=heads_num,
-                    mlp_width_ratio=mlp_width_ratio,
-                    mlp_drop_rate=mlp_drop_rate,
-                    act_type=act_type,
-                    qk_norm=qk_norm,
-                    qk_norm_type=qk_norm_type,
-                    qkv_bias=qkv_bias,
-                    need_CA=self.need_CA,
-                    **factory_kwargs,
-                )
-                for _ in range(depth)
-            ]
-        )
-
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        c: torch.LongTensor,
-        mask: Optional[torch.Tensor] = None,
-        y:torch.Tensor=None,
-    ):
-        self_attn_mask = None
-        if mask is not None:
-            batch_size = mask.shape[0]
-            seq_len = mask.shape[1]
-            mask = mask.to(x.device)
-            # batch_size x 1 x seq_len x seq_len
-            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
-                1, 1, seq_len, 1
-            )
-            # batch_size x 1 x seq_len x seq_len
-            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
-            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
-            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
-            # avoids self-attention weight being NaN for padding tokens
-            self_attn_mask[:, :, :, 0] = True
-
-
-        for block in self.blocks:
-            x = block(x, c, self_attn_mask,y)
-
-        return x
-
-
-class SingleTokenRefiner(torch.nn.Module):
-    """
-    A single token refiner block for llm text embedding refine.
-    """
-    def __init__(
-        self,
-        in_channels,
-        hidden_size,
-        heads_num,
-        depth,
-        mlp_width_ratio: float = 4.0,
-        mlp_drop_rate: float = 0.0,
-        act_type: str = "silu",
-        qk_norm: bool = False,
-        qk_norm_type: str = "layer",
-        qkv_bias: bool = True,
-        need_CA:bool=False,
-        attn_mode: str = "torch",
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.attn_mode = attn_mode
-        self.need_CA = need_CA
-        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
-
-        self.input_embedder = nn.Linear(
-            in_channels, hidden_size, bias=True, **factory_kwargs
-        )
-        if self.need_CA:
-            self.input_embedder_CA = nn.Linear(
-                in_channels, hidden_size, bias=True, **factory_kwargs
-            )
-
-        act_layer = get_activation_layer(act_type)
-        # Build timestep embedding layer
-        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
-        # Build context embedding layer
-        self.c_embedder = TextProjection(
-            in_channels, hidden_size, act_layer, **factory_kwargs
-        )
-
-        self.individual_token_refiner = IndividualTokenRefiner(
-            hidden_size=hidden_size,
-            heads_num=heads_num,
-            depth=depth,
-            mlp_width_ratio=mlp_width_ratio,
-            mlp_drop_rate=mlp_drop_rate,
-            act_type=act_type,
-            qk_norm=qk_norm,
-            qk_norm_type=qk_norm_type,
-            qkv_bias=qkv_bias,
-            need_CA=need_CA,
-            **factory_kwargs,
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        t: torch.LongTensor,
-        mask: Optional[torch.LongTensor] = None,
-        y: torch.LongTensor=None,
-    ):
-        timestep_aware_representations = self.t_embedder(t)
-
-        if mask is None:
-            context_aware_representations = x.mean(dim=1)
-        else:
-            mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
-            context_aware_representations = (x * mask_float).sum(
-                dim=1
-            ) / mask_float.sum(dim=1)
-        context_aware_representations = self.c_embedder(context_aware_representations)
-        c = timestep_aware_representations + context_aware_representations
-
-        x = self.input_embedder(x)
-        if self.need_CA:
-            y = self.input_embedder_CA(y)
-            x = self.individual_token_refiner(x, c, mask, y)
-        else:
-            x = self.individual_token_refiner(x, c, mask)
-
-        return x
-
-
-
-class Qwen2Connector(torch.nn.Module):
-    def __init__(
-        self,
-        # biclip_dim=1024,
-        in_channels=3584,
-        hidden_size=4096,
-        heads_num=32,
-        depth=2,
-        need_CA=False,
-        device=None,
-        dtype=torch.bfloat16,
-    ):
-        super().__init__()
-        factory_kwargs = {"device": device, "dtype":dtype}
-
-        self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
-        self.global_proj_out=nn.Linear(in_channels,768)
-
-        self.scale_factor = nn.Parameter(torch.zeros(1))
-        with torch.no_grad():
-            self.scale_factor.data += -(1 - 0.09)
-
-    def forward(self, x,t,mask):
-        mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
-        x_mean = (x * mask_float).sum(
-            dim=1
-        ) / mask_float.sum(dim=1) * (1 + self.scale_factor)
-
-        global_out=self.global_proj_out(x_mean)
-        encoder_hidden_states = self.S(x,t,mask)
-        return encoder_hidden_states,global_out
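Note on the file above: Qwen2Connector is the piece that model_edit.py (removed further down in this diff) instantiates as Step1XEdit.connector. A minimal sketch of calling it with dummy inputs, based only on the signatures in the removed code; the import path, batch size, and sequence length are illustrative assumptions, not part of this commit:

import torch

from modules.connector_edit import Qwen2Connector  # module path as it existed before this commit

# Defaults mirror the removed __init__: in_channels=3584, hidden_size=4096, heads_num=32, depth=2.
# float32 keeps every submodule (including the non-factory global_proj_out) in a single dtype.
connector = Qwen2Connector(dtype=torch.float32)

x = torch.randn(1, 64, 3584)                # assumed LLM hidden states [batch, seq_len, in_channels]
t = torch.zeros(1)                          # one timestep value per batch element
mask = torch.ones(1, 64, dtype=torch.long)  # assumed padding mask [batch, seq_len]

encoder_hidden_states, global_out = connector(x, t, mask)
# encoder_hidden_states: [1, 64, 4096] refined tokens; global_out: [1, 768] pooled projection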
modules/layers.py
DELETED
@@ -1,640 +0,0 @@
-# Modified from Flux
-#
-# Copyright 2024 Black Forest Labs
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import math  # noqa: I001
-from dataclasses import dataclass
-from functools import partial
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-# from liger_kernel.ops.rms_norm import LigerRMSNormFunction
-from torch import Tensor, nn
-
-
-try:
-    import flash_attn
-    from flash_attn.flash_attn_interface import (
-        _flash_attn_forward,
-        flash_attn_varlen_func,
-    )
-except ImportError:
-    flash_attn = None
-    flash_attn_varlen_func = None
-    _flash_attn_forward = None
-
-
-MEMORY_LAYOUT = {
-    "flash": (
-        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
-        lambda x: x,
-    ),
-    "torch": (
-        lambda x: x.transpose(1, 2),
-        lambda x: x.transpose(1, 2),
-    ),
-    "vanilla": (
-        lambda x: x.transpose(1, 2),
-        lambda x: x.transpose(1, 2),
-    ),
-}
-
-
-def attention(
-    q,
-    k,
-    v,
-    mode="torch",
-    drop_rate=0,
-    attn_mask=None,
-    causal=False,
-    cu_seqlens_q=None,
-    cu_seqlens_kv=None,
-    max_seqlen_q=None,
-    max_seqlen_kv=None,
-    batch_size=1,
-):
-    """
-    Perform QKV self attention.
-
-    Args:
-        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
-        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
-        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
-        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
-        drop_rate (float): Dropout rate in attention map. (default: 0)
-        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
-            (default: None)
-        causal (bool): Whether to use causal attention. (default: False)
-        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
-            used to index into q.
-        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
-            used to index into kv.
-        max_seqlen_q (int): The maximum sequence length in the batch of q.
-        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
-
-    Returns:
-        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
-    """
-    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
-    q = pre_attn_layout(q)
-    k = pre_attn_layout(k)
-    v = pre_attn_layout(v)
-
-    if mode == "torch":
-        if attn_mask is not None and attn_mask.dtype != torch.bool:
-            attn_mask = attn_mask.to(q.dtype)
-        x = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
-        )
-    elif mode == "flash":
-        assert flash_attn_varlen_func is not None
-        x: torch.Tensor = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens_q,
-            cu_seqlens_kv,
-            max_seqlen_q,
-            max_seqlen_kv,
-        )  # type: ignore
-        # x with shape [(bxs), a, d]
-        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # type: ignore # reshape x to [b, s, a, d]
-    elif mode == "vanilla":
-        scale_factor = 1 / math.sqrt(q.size(-1))
-
-        b, a, s, _ = q.shape
-        s1 = k.size(2)
-        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
-        if causal:
-            # Only applied to self attention
-            assert attn_mask is None, (
-                "Causal mask and attn_mask cannot be used together"
-            )
-            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
-                diagonal=0
-            )
-            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-            attn_bias.to(q.dtype)
-
-        if attn_mask is not None:
-            if attn_mask.dtype == torch.bool:
-                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-            else:
-                attn_bias += attn_mask
-
-        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
-        attn = (q @ k.transpose(-2, -1)) * scale_factor
-        attn += attn_bias
-        attn = attn.softmax(dim=-1)
-        attn = torch.dropout(attn, p=drop_rate, train=True)
-        x = attn @ v
-    else:
-        raise NotImplementedError(f"Unsupported attention mode: {mode}")
-
-    x = post_attn_layout(x)
-    b, s, a, d = x.shape
-    out = x.reshape(b, s, -1)
-    return out
-
-
-def apply_gate(x, gate=None, tanh=False):
-    """Apply an optional gating tensor to x.
-
-    Args:
-        x (torch.Tensor): input tensor.
-        gate (torch.Tensor, optional): gate tensor. Defaults to None.
-        tanh (bool, optional): whether to use tanh function. Defaults to False.
-
-    Returns:
-        torch.Tensor: the output tensor after apply gate.
-    """
-    if gate is None:
-        return x
-    if tanh:
-        return x * gate.unsqueeze(1).tanh()
-    else:
-        return x * gate.unsqueeze(1)
-
-
-class MLP(nn.Module):
-    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
-
-    def __init__(
-        self,
-        in_channels,
-        hidden_channels=None,
-        out_features=None,
-        act_layer=nn.GELU,
-        norm_layer=None,
-        bias=True,
-        drop=0.0,
-        use_conv=False,
-        device=None,
-        dtype=None,
-    ):
-        super().__init__()
-        out_features = out_features or in_channels
-        hidden_channels = hidden_channels or in_channels
-        bias = (bias, bias)
-        drop_probs = (drop, drop)
-        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
-
-        self.fc1 = linear_layer(
-            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
-        )
-        self.act = act_layer()
-        self.drop1 = nn.Dropout(drop_probs[0])
-        self.norm = (
-            norm_layer(hidden_channels, device=device, dtype=dtype)
-            if norm_layer is not None
-            else nn.Identity()
-        )
-        self.fc2 = linear_layer(
-            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
-        )
-        self.drop2 = nn.Dropout(drop_probs[1])
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop1(x)
-        x = self.norm(x)
-        x = self.fc2(x)
-        x = self.drop2(x)
-        return x
-
-
-class TextProjection(nn.Module):
-    """
-    Projects text embeddings. Also handles dropout for classifier-free guidance.
-
-    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
-    """
-
-    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
-        factory_kwargs = {"dtype": dtype, "device": device}
-        super().__init__()
-        self.linear_1 = nn.Linear(
-            in_features=in_channels,
-            out_features=hidden_size,
-            bias=True,
-            **factory_kwargs,
-        )
-        self.act_1 = act_layer()
-        self.linear_2 = nn.Linear(
-            in_features=hidden_size,
-            out_features=hidden_size,
-            bias=True,
-            **factory_kwargs,
-        )
-
-    def forward(self, caption):
-        hidden_states = self.linear_1(caption)
-        hidden_states = self.act_1(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
-
-
-class TimestepEmbedder(nn.Module):
-    """
-    Embeds scalar timesteps into vector representations.
-    """
-
-    def __init__(
-        self,
-        hidden_size,
-        act_layer,
-        frequency_embedding_size=256,
-        max_period=10000,
-        out_size=None,
-        dtype=None,
-        device=None,
-    ):
-        factory_kwargs = {"dtype": dtype, "device": device}
-        super().__init__()
-        self.frequency_embedding_size = frequency_embedding_size
-        self.max_period = max_period
-        if out_size is None:
-            out_size = hidden_size
-
-        self.mlp = nn.Sequential(
-            nn.Linear(
-                frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
-            ),
-            act_layer(),
-            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
-        )
-        nn.init.normal_(self.mlp[0].weight, std=0.02)  # type: ignore
-        nn.init.normal_(self.mlp[2].weight, std=0.02)  # type: ignore
-
-    @staticmethod
-    def timestep_embedding(t, dim, max_period=10000):
-        """
-        Create sinusoidal timestep embeddings.
-
-        Args:
-            t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
-            dim (int): the dimension of the output.
-            max_period (int): controls the minimum frequency of the embeddings.
-
-        Returns:
-            embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
-
-        .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-        """
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period)
-            * torch.arange(start=0, end=half, dtype=torch.float32)
-            / half
-        ).to(device=t.device)
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat(
-                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
-            )
-        return embedding
-
-    def forward(self, t):
-        t_freq = self.timestep_embedding(
-            t, self.frequency_embedding_size, self.max_period
-        ).type(self.mlp[0].weight.dtype)  # type: ignore
-        t_emb = self.mlp(t_freq)
-        return t_emb
-
-
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: Tensor) -> Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-
-        return emb.unsqueeze(1)
-
-
-class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int):
-        super().__init__()
-        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
-        self.silu = nn.SiLU()
-        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.out_layer(self.silu(self.in_layer(x)))
-
-
-def rope(pos, dim: int, theta: int):
-    assert dim % 2 == 0
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    out = torch.stack(
-        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
-    )
-    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
-
-
-def attention_after_rope(q, k, v, pe):
-    q, k = apply_rope(q, k, pe)
-
-    from .attention import attention
-
-    x = attention(q, k, v, mode="torch")
-    return x
-
-
-@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
-def apply_rope(xq, xk, freqs_cis):
-    # Swap num_heads and seq_len back to the dimension order the original function expects
-    xq = xq.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
-    xk = xk.transpose(1, 2)
-
-    # Split head_dim into complex components (real and imaginary parts)
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-
-    # Apply the rotary position embedding (complex multiplication)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-
-    # Restore the tensor shapes and transpose back to the target dimension order
-    xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
-    xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
-
-    return xq_out, xk_out
-
-
-@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
-def scale_add_residual(
-    x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
-) -> torch.Tensor:
-    return x * scale + residual
-
-
-@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
-def layernorm_and_scale_shift(
-    x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
-) -> torch.Tensor:
-    return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift
-
-
-class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.norm = QKNorm(head_dim)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
-        qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-        x = attention_after_rope(q, k, v, pe=pe)
-        x = self.proj(x)
-        return x
-
-
-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int):
-        super().__init__()
-        self.scale = nn.Parameter(torch.ones(dim))
-
-    # @staticmethod
-    # def rms_norm_fast(x, weight, eps):
-    #     return LigerRMSNormFunction.apply(
-    #         x,
-    #         weight,
-    #         eps,
-    #         0.0,
-    #         "gemma",
-    #         True,
-    #     )
-
-    @staticmethod
-    def rms_norm(x, weight, eps):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        return (x * rrms).to(dtype=x_dtype) * weight
-
-    def forward(self, x: Tensor):
-        # return self.rms_norm_fast(x, self.scale, 1e-6)
-        return self.rms_norm(x, self.scale, 1e-6)
-
-
-class QKNorm(torch.nn.Module):
-    def __init__(self, dim: int):
-        super().__init__()
-        self.query_norm = RMSNorm(dim)
-        self.key_norm = RMSNorm(dim)
-
-    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-        return q.to(v), k.to(v)
-
-
-class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool):
-        super().__init__()
-        self.is_double = double
-        self.multiplier = 6 if double else 3
-        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
-
-    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
-        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
-            self.multiplier, dim=-1
-        )
-
-        return (
-            ModulationOut(*out[:3]),
-            ModulationOut(*out[3:]) if self.is_double else None,
-        )
-
-
-class DoubleStreamBlock(nn.Module):
-    def __init__(
-        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
-    ):
-        super().__init__()
-
-        mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.img_mod = Modulation(hidden_size, double=True)
-        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.img_attn = SelfAttention(
-            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
-        )
-
-        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.img_mlp = nn.Sequential(
-            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
-            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
-        )
-
-        self.txt_mod = Modulation(hidden_size, double=True)
-        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.txt_attn = SelfAttention(
-            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
-        )
-
-        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.txt_mlp = nn.Sequential(
-            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
-            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
-        )
-
-    def forward(
-        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
-    ) -> tuple[Tensor, Tensor]:
-        img_mod1, img_mod2 = self.img_mod(vec)
-        txt_mod1, txt_mod2 = self.txt_mod(vec)
-
-        # prepare image for attention
-        img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
-        img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(
-            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
-        )
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-        # prepare txt for attention
-        txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
-        txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(
-            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
-        )
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-        # run actual attention
-        q = torch.cat((txt_q, img_q), dim=1)
-        k = torch.cat((txt_k, img_k), dim=1)
-        v = torch.cat((txt_v, img_v), dim=1)
-
-        attn = attention_after_rope(q, k, v, pe=pe)
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
-        # calculate the img blocks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img_mlp = self.img_mlp(
-            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
-        )
-        img = scale_add_residual(img_mlp, img_mod2.gate, img)
-
-        # calculate the txt blocks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt_mlp = self.txt_mlp(
-            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
-        )
-        txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
-        return img, txt
-
-
-class SingleStreamBlock(nn.Module):
-    """
-    A DiT block with parallel linear layers as described in
-    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qk_scale: float | None = None,
-    ):
-        super().__init__()
-        self.hidden_dim = hidden_size
-        self.num_heads = num_heads
-        head_dim = hidden_size // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        # qkv and mlp_in
-        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
-        # proj and mlp_out
-        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
-
-        self.norm = QKNorm(head_dim)
-
-        self.hidden_size = hidden_size
-        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-
-        self.mlp_act = nn.GELU(approximate="tanh")
-        self.modulation = Modulation(hidden_size, double=False)
-
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
-        mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
-        qkv, mlp = torch.split(
-            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
-        )
-
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-
-        # compute attention
-        attn = attention_after_rope(q, k, v, pe=pe)
-        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        return scale_add_residual(output, mod.gate, x)
-
-
-class LastLayer(nn.Module):
-    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
-        super().__init__()
-        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.linear = nn.Linear(
-            hidden_size, patch_size * patch_size * out_channels, bias=True
-        )
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
-        )
-
-    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
-        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
-        x = self.linear(x)
-        return x
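layers.py above is the shared toolbox (attention helper, RoPE utilities, modulation, and the Flux-style stream blocks) imported by both connector_edit.py (above) and model_edit.py (below) in this commit. A minimal sketch of the generic attention() helper in its default "torch" mode, using the [batch, seq_len, heads, head_dim] layout from its docstring; the concrete sizes and the import path are illustrative assumptions:

import torch

from modules.layers import attention  # module path as it existed before this commit

q = torch.randn(2, 16, 8, 64)  # [batch, seq_len, heads, head_dim]
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)

out = attention(q, k, v, mode="torch")  # dispatches to F.scaled_dot_product_attention
print(out.shape)  # torch.Size([2, 16, 512]) -- heads and head_dim flattened to [b, s, a * d]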
modules/model_edit.py
DELETED
@@ -1,143 +0,0 @@
-import math
-from dataclasses import dataclass
-
-import numpy as np
-import torch
-from torch import Tensor, nn
-
-from .connector_edit import Qwen2Connector
-from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock
-
-
-@dataclass
-class Step1XParams:
-    in_channels: int
-    out_channels: int
-    vec_in_dim: int
-    context_in_dim: int
-    hidden_size: int
-    mlp_ratio: float
-    num_heads: int
-    depth: int
-    depth_single_blocks: int
-    axes_dim: list[int]
-    theta: int
-    qkv_bias: bool
-
-
-class Step1XEdit(nn.Module):
-    """
-    Transformer model for flow matching on sequences.
-    """
-
-    def __init__(self, params: Step1XParams):
-        super().__init__()
-
-        self.params = params
-        self.in_channels = params.in_channels
-        self.out_channels = params.out_channels
-        if params.hidden_size % params.num_heads != 0:
-            raise ValueError(
-                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
-            )
-        pe_dim = params.hidden_size // params.num_heads
-        if sum(params.axes_dim) != pe_dim:
-            raise ValueError(
-                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
-            )
-        self.hidden_size = params.hidden_size
-        self.num_heads = params.num_heads
-        self.pe_embedder = EmbedND(
-            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
-        )
-        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
-        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
-        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
-
-        self.double_blocks = nn.ModuleList(
-            [
-                DoubleStreamBlock(
-                    self.hidden_size,
-                    self.num_heads,
-                    mlp_ratio=params.mlp_ratio,
-                    qkv_bias=params.qkv_bias,
-                )
-                for _ in range(params.depth)
-            ]
-        )
-
-        self.single_blocks = nn.ModuleList(
-            [
-                SingleStreamBlock(
-                    self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
-                )
-                for _ in range(params.depth_single_blocks)
-            ]
-        )
-
-        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
-
-        self.connector = Qwen2Connector()
-
-    @staticmethod
-    def timestep_embedding(
-        t: Tensor, dim, max_period=10000, time_factor: float = 1000.0
-    ):
-        """
-        Create sinusoidal timestep embeddings.
-        :param t: a 1-D Tensor of N indices, one per batch element.
-                  These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an (N, D) Tensor of positional embeddings.
-        """
-        t = time_factor * t
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period)
-            * torch.arange(start=0, end=half, dtype=torch.float32)
-            / half
-        ).to(t.device)
-
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat(
-                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
-            )
-        if torch.is_floating_point(t):
-            embedding = embedding.to(t)
-        return embedding
-
-    def forward(
-        self,
-        img: Tensor,
-        img_ids: Tensor,
-        txt: Tensor,
-        txt_ids: Tensor,
-        timesteps: Tensor,
-        y: Tensor,
-    ) -> Tensor:
-        if img.ndim != 3 or txt.ndim != 3:
-            raise ValueError("Input img and txt tensors must have 3 dimensions.")
-
-        img = self.img_in(img)
-        vec = self.time_in(self.timestep_embedding(timesteps, 256))
-
-        vec = vec + self.vector_in(y)
-        txt = self.txt_in(txt)
-
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
-
-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
-
-        img = torch.cat((txt, img), 1)
-        for block in self.single_blocks:
-            img = block(img, vec=vec, pe=pe)
-        img = img[:, txt.shape[1] :, ...]
-
-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
-        return img
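For context on the removed wrapper above: Step1XEdit validates that hidden_size is divisible by num_heads and that sum(axes_dim) equals the per-head dimension before assembling its double/single stream blocks, and it always builds a default Qwen2Connector alongside them. A minimal construction sketch under those constraints; the numeric values are illustrative assumptions (depths reduced), not the shipped configuration, and the import path is the one that existed before this commit:

from modules.model_edit import Step1XEdit, Step1XParams

params = Step1XParams(
    in_channels=64,          # per-token image feature size fed to img_in
    out_channels=64,
    vec_in_dim=768,          # same width as the connector's global_proj_out above
    context_in_dim=4096,     # same width as the connector's default hidden_size
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,            # 3072 / 24 = 128 = sum(axes_dim)
    depth=2,                 # reduced for the sketch
    depth_single_blocks=2,   # reduced for the sketch
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
)
model = Step1XEdit(params)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")

Because the constructor also instantiates the default Qwen2Connector, even this reduced-depth sketch allocates a sizeable model.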
no_cookie.png
DELETED
Git LFS Details

poster.jpg
DELETED
Binary file (65.4 kB)

poster_orig.jpg
DELETED
Git LFS Details