Update transformer_bria.py
transformer_bria.py  +8 -28  CHANGED
@@ -10,38 +10,18 @@ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_l
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
 from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
-
-# Support different diffusers versions
-try:
-    from diffusers.models.embeddings import FluxPosEmbed as EmbedND
-except:
-    from diffusers.models.transformers.transformer_flux import rope
-    class EmbedND(nn.Module):
-        def __init__(self, theta: int, axes_dim: List[int]):
-            super().__init__()
-            self.theta = theta
-            self.axes_dim = axes_dim
-
-        def forward(self, ids: torch.Tensor) -> torch.Tensor:
-            n_axes = ids.shape[-1]
-            emb = torch.cat(
-                [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-                dim=-3,
-            )
-            return emb.unsqueeze(1)
-
-
+from bria_utils import FluxPosEmbed as EmbedND
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 class Timesteps(nn.Module):
-    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1,
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1,time_theta=10000):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
-        self.
+        self.time_theta=time_theta
 
     def forward(self, timesteps):
         t_emb = get_timestep_embedding(
@@ -50,15 +30,15 @@ class Timesteps(nn.Module):
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
             scale=self.scale,
-            max_period=self.
+            max_period=self.time_theta
         )
         return t_emb
 
 class TimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim,
+    def __init__(self, embedding_dim, time_theta):
         super().__init__()
 
-        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0,
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0,time_theta=time_theta)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
 
     def forward(self, timestep, dtype):
@@ -106,7 +86,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         guidance_embeds: bool = False,
         axes_dims_rope: List[int] = [16, 56, 56],
         rope_theta = 10000,
-
+        time_theta = 10000
     ):
         super().__init__()
         self.out_channels = in_channels
@@ -116,7 +96,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
 
 
         self.time_embed = TimestepProjEmbeddings(
-            embedding_dim=self.inner_dim,
+            embedding_dim=self.inner_dim,time_theta=time_theta
        )
 
         # if pooled_projection_dim:
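The first hunk drops the version-probing try/except and instead imports FluxPosEmbed from the Space's own bria_utils module, so the transformer no longer depends on where a given diffusers release keeps that class. bria_utils itself is not part of this commit; the sketch below reconstructs what it plausibly exports, based on the EmbedND fallback deleted here. The module name comes from the added import line, but its contents are an assumption, and the rope import only resolves on diffusers versions that still expose it from transformer_flux:

# bria_utils.py -- hypothetical reconstruction from the deleted EmbedND
# fallback; the Space's actual file may differ.
from typing import List

import torch
import torch.nn as nn
from diffusers.models.transformers.transformer_flux import rope  # assumed still exported here


class FluxPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta        # rotary base, wired to rope_theta in the model config
        self.axes_dim = axes_dim  # per-axis widths, e.g. [16, 56, 56] from axes_dims_rope

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # One rotary embedding per positional axis, concatenated and given a
        # broadcast dimension, matching diffusers' FluxPosEmbed interface.
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)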
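The remaining hunks thread a time_theta argument from the BriaTransformer2DModel config through TimestepProjEmbeddings and Timesteps down to get_timestep_embedding, where it becomes max_period of the sinusoidal timestep embedding. Since the new default is 10000, the value get_timestep_embedding already used, existing checkpoints behave identically; the change only makes the base configurable. A minimal sketch of the effect using the public diffusers helper (the tensor values are illustrative, not taken from this commit):

import torch
from diffusers.models.embeddings import get_timestep_embedding

t = torch.tensor([999.0, 500.0, 1.0])

# What Timesteps.forward computes after this commit: time_theta -> max_period.
emb_default = get_timestep_embedding(
    t, embedding_dim=256,
    flip_sin_to_cos=True, downscale_freq_shift=0,
    max_period=10000,   # the time_theta default; identical to the old hard-coded behavior
)
emb_wide = get_timestep_embedding(
    t, embedding_dim=256,
    flip_sin_to_cos=True, downscale_freq_shift=0,
    max_period=100000,  # a non-default time_theta, now reachable from the model config
)
print(emb_default.shape, torch.allclose(emb_default, emb_wide))  # torch.Size([3, 256]) False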