Spaces:
Running
Running
Shuming Ma
Shuming Ma
committed on
fix: DeepNet doesn't scale weights of embedding/output layers (#150)
Browse files
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -883,9 +883,7 @@ class FlaxBartEncoder(FlaxBartEncoder):
|
|
| 883 |
self.embed_positions = nn.Embed(
|
| 884 |
self.config.max_text_length + self.offset,
|
| 885 |
embed_dim,
|
| 886 |
-
embedding_init=
|
| 887 |
-
if self.config.use_deepnet_scaling
|
| 888 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
| 889 |
)
|
| 890 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
| 891 |
self.layernorm_embedding = norm(
|
|
@@ -917,9 +915,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
|
|
| 917 |
self.embed_positions = nn.Embed(
|
| 918 |
self.config.image_length + self.offset, # image length for BOS
|
| 919 |
embed_dim,
|
| 920 |
-
embedding_init=
|
| 921 |
-
if self.config.use_deepnet_scaling
|
| 922 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
| 923 |
)
|
| 924 |
|
| 925 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
|
@@ -939,16 +935,12 @@ class FlaxBartModule(FlaxBartModule):
|
|
| 939 |
encoder_embed_tokens = nn.Embed(
|
| 940 |
self.config.encoder_vocab_size,
|
| 941 |
self.config.d_model,
|
| 942 |
-
embedding_init=
|
| 943 |
-
if self.config.use_deepnet_scaling
|
| 944 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
| 945 |
)
|
| 946 |
decoder_embed_tokens = nn.Embed(
|
| 947 |
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
|
| 948 |
self.config.d_model,
|
| 949 |
-
embedding_init=
|
| 950 |
-
if self.config.use_deepnet_scaling
|
| 951 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
| 952 |
)
|
| 953 |
|
| 954 |
self.encoder = FlaxBartEncoder(
|
|
@@ -1288,9 +1280,7 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
|
|
| 1288 |
+ 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
|
| 1289 |
use_bias=False,
|
| 1290 |
dtype=self.dtype,
|
| 1291 |
-
kernel_init=
|
| 1292 |
-
if self.config.use_deepnet_scaling
|
| 1293 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
| 1294 |
)
|
| 1295 |
|
| 1296 |
def __call__(
|
|
|
|
| 883 |
self.embed_positions = nn.Embed(
|
| 884 |
self.config.max_text_length + self.offset,
|
| 885 |
embed_dim,
|
| 886 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
|
|
| 887 |
)
|
| 888 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
| 889 |
self.layernorm_embedding = norm(
|
|
|
|
| 915 |
self.embed_positions = nn.Embed(
|
| 916 |
self.config.image_length + self.offset, # image length for BOS
|
| 917 |
embed_dim,
|
| 918 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
|
|
| 919 |
)
|
| 920 |
|
| 921 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
|
|
|
| 935 |
encoder_embed_tokens = nn.Embed(
|
| 936 |
self.config.encoder_vocab_size,
|
| 937 |
self.config.d_model,
|
| 938 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
|
|
| 939 |
)
|
| 940 |
decoder_embed_tokens = nn.Embed(
|
| 941 |
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
|
| 942 |
self.config.d_model,
|
| 943 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
|
|
| 944 |
)
|
| 945 |
|
| 946 |
self.encoder = FlaxBartEncoder(
|
|
|
|
| 1280 |
+ 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
|
| 1281 |
use_bias=False,
|
| 1282 |
dtype=self.dtype,
|
| 1283 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
|
|
| 1284 |
)
|
| 1285 |
|
| 1286 |
def __call__(
|