Eyalgut committed
Commit ef6db9b · verified · 1 parent: 8baa624

Update transformer_bria.py

Files changed (1)
1. transformer_bria.py (+8 -28)
transformer_bria.py CHANGED
@@ -10,38 +10,18 @@ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_l
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
 from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
-
-# Support different diffusers versions
-try:
-    from diffusers.models.embeddings import FluxPosEmbed as EmbedND
-except:
-    from diffusers.models.transformers.transformer_flux import rope
-    class EmbedND(nn.Module):
-        def __init__(self, theta: int, axes_dim: List[int]):
-            super().__init__()
-            self.theta = theta
-            self.axes_dim = axes_dim
-
-        def forward(self, ids: torch.Tensor) -> torch.Tensor:
-            n_axes = ids.shape[-1]
-            emb = torch.cat(
-                [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-                dim=-3,
-            )
-            return emb.unsqueeze(1)
-
-
+from bria_utils import FluxPosEmbed as EmbedND
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 class Timesteps(nn.Module):
-    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, max_period=10000):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
-        self.max_period = max_period
+        self.time_theta = time_theta
 
     def forward(self, timesteps):
         t_emb = get_timestep_embedding(
@@ -50,15 +30,15 @@ class Timesteps(nn.Module):
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
             scale=self.scale,
-            max_period=self.max_period
+            max_period=self.time_theta
         )
         return t_emb
 
 class TimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, max_period):
+    def __init__(self, embedding_dim, time_theta):
         super().__init__()
 
-        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, max_period=max_period)
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
 
     def forward(self, timestep, dtype):
@@ -106,7 +86,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         guidance_embeds: bool = False,
         axes_dims_rope: List[int] = [16, 56, 56],
         rope_theta = 10000,
-        max_period = 10000
+        time_theta = 10000
     ):
         super().__init__()
         self.out_channels = in_channels
@@ -116,7 +96,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
 
 
         self.time_embed = TimestepProjEmbeddings(
-            embedding_dim=self.inner_dim, max_period=max_period
+            embedding_dim=self.inner_dim, time_theta=time_theta
         )
 
         # if pooled_projection_dim:
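
For reference, a minimal usage sketch of the renamed argument; this is not part of the commit. The class and argument names come from the diff above, while the embedding_dim value, the batch size, and the timestep value are illustrative assumptions. Note that time_theta is still forwarded to diffusers' get_timestep_embedding as max_period, so the sinusoidal embedding itself is unchanged; the rename presumably just mirrors the existing rope_theta constructor argument.

import torch
from transformer_bria import TimestepProjEmbeddings

# `time_theta` replaces the old `max_period` keyword end to end:
# TimestepProjEmbeddings -> Timesteps -> get_timestep_embedding(max_period=...).
time_embed = TimestepProjEmbeddings(embedding_dim=3072, time_theta=10000)  # embedding_dim is an assumed value

timesteps = torch.tensor([500.0])  # a single denoising timestep
t_emb = time_embed(timesteps, dtype=torch.float32)
print(t_emb.shape)  # expected: torch.Size([1, 3072])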