Spaces:
Runtime error
Runtime error
Merge pull request #7 from LightricksResearch/feature/fix-transformer-init-bug
Browse files
xora/models/transformers/transformer3d.py
CHANGED
|
@@ -186,14 +186,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
| 186 |
|
| 187 |
# Zero-out adaLN modulation layers in PixArt blocks:
|
| 188 |
for block in self.transformer_blocks:
|
| 189 |
-
if mode == "xora":
|
| 190 |
nn.init.constant_(block.attn1.to_out[0].weight, 0)
|
| 191 |
nn.init.constant_(block.attn1.to_out[0].bias, 0)
|
| 192 |
|
| 193 |
nn.init.constant_(block.attn2.to_out[0].weight, 0)
|
| 194 |
nn.init.constant_(block.attn2.to_out[0].bias, 0)
|
| 195 |
|
| 196 |
-
if mode == "xora":
|
| 197 |
nn.init.constant_(block.ff.net[2].weight, 0)
|
| 198 |
nn.init.constant_(block.ff.net[2].bias, 0)
|
| 199 |
|
|
|
|
| 186 |
|
| 187 |
# Zero-out adaLN modulation layers in PixArt blocks:
|
| 188 |
for block in self.transformer_blocks:
|
| 189 |
+
if mode.lower() == "xora":
|
| 190 |
nn.init.constant_(block.attn1.to_out[0].weight, 0)
|
| 191 |
nn.init.constant_(block.attn1.to_out[0].bias, 0)
|
| 192 |
|
| 193 |
nn.init.constant_(block.attn2.to_out[0].weight, 0)
|
| 194 |
nn.init.constant_(block.attn2.to_out[0].bias, 0)
|
| 195 |
|
| 196 |
+
if mode.lower() == "xora":
|
| 197 |
nn.init.constant_(block.ff.net[2].weight, 0)
|
| 198 |
nn.init.constant_(block.ff.net[2].bias, 0)
|
| 199 |
|