Fix init and repro (#48)

* Fix init and repro
* comment + black

---------

Co-authored-by: Srini Iyer <[email protected]>
- bytelatent/base_transformer.py +19 -14
- bytelatent/distributed.py +15 -7
- bytelatent/model/blt.py +8 -48
- bytelatent/model/latent_transformer.py +12 -11
- bytelatent/model/local_models.py +15 -12
- bytelatent/transformer.py +19 -8
bytelatent/base_transformer.py
CHANGED

@@ -445,7 +445,7 @@ class Attention(nn.Module):
         return output

     def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
+        init_std = init_std or (self.dim ** (-0.5)) / factor

         for w in [self.wq, self.wk, self.wv]:
             nn.init.trunc_normal_(

@@ -459,7 +459,7 @@ class Attention(nn.Module):
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=init_std
+            std=init_std,
             a=-3 * init_std,
             b=3 * init_std,
         )

@@ -509,18 +509,16 @@ class FeedForward(nn.Module):
         return output

     def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-
-
-
-
-
-
-
-
-            b=3 * in_init_std,
-        )
+        in_init_std = init_std or (self.dim ** (-0.5)) / factor
+        out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
+
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
         nn.init.trunc_normal_(
             self.w2.weight,
             mean=0.0,

@@ -528,6 +526,13 @@ class FeedForward(nn.Module):
             a=-3 * out_init_std,
             b=3 * out_init_std,
         )
+        nn.init.trunc_normal_(
+            self.w3.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )


 class TransformerBlock(nn.Module):
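The recurring idiom in these init fixes is `init_std = init_std or (self.dim ** (-0.5)) / factor`: when no explicit std is passed, the fallback is dim ** -0.5 scaled down by the per-layer factor, and that same std now feeds every projection. A standalone sanity-check of the precedence and the truncation convention (illustrative sketch, not code from the repository):

import torch.nn as nn

dim, factor = 1024, 2.0

# "a or b / c" parses as "a or (b / c)": factor only scales the fallback std.
init_std = None
std = init_std or (dim ** (-0.5)) / factor
assert std == (dim ** -0.5) / factor

# An explicit init_std short-circuits the "or" and is used unscaled.
assert (0.02 or (dim ** (-0.5)) / factor) == 0.02

# Same truncated-normal convention as in the diff: clip at +/- 3 std.
w = nn.Linear(dim, dim, bias=False)
nn.init.trunc_normal_(w.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
assert w.weight.abs().max().item() <= 3 * std + 1e-6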
bytelatent/distributed.py
CHANGED

@@ -463,13 +463,21 @@ def parallelize_model(
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")

     if distributed_args.selective_activation_checkpointing:
-
-
-
-
-
-
-
+        # only works for blt models
+        # assuming that entropy models will not use checkpointing
+        for module in [
+            model.global_transformer,
+            model.local_encoder,
+            model.local_decoder,
+        ]:
+            for i in range(len(module.layers)):
+                module.layers[i] = checkpoint_wrapper(
+                    module.layers[i],
+                    context_fn=partial(
+                        create_selective_checkpoint_contexts,
+                        get_default_policy(no_recompute_ops),
+                    ),
+                )

     if distributed_args.compile:
         torch._dynamo.config.cache_size_limit = (
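The fix applies selective activation checkpointing per transformer block of each BLT sub-model rather than to the model as a whole. A minimal sketch of the same wrapping pattern, assuming a module that exposes a `layers` ModuleList; `ops_to_save` here is an illustrative stand-in for the repository's `get_default_policy(no_recompute_ops)`:

from functools import partial

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.utils.checkpoint import create_selective_checkpoint_contexts

# Illustrative policy: keep matmul outputs, recompute everything else on backward.
ops_to_save = [torch.ops.aten.mm.default]

def wrap_layers_with_selective_checkpointing(module: nn.Module) -> None:
    # Replace each block in module.layers with a checkpointed wrapper, mirroring
    # the loop over global_transformer / local_encoder / local_decoder above.
    for i in range(len(module.layers)):
        module.layers[i] = checkpoint_wrapper(
            module.layers[i],
            context_fn=partial(create_selective_checkpoint_contexts, ops_to_save),
        )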
bytelatent/model/blt.py
CHANGED

@@ -825,12 +825,6 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_dim=self.local_encoder.dim,
             encoder_hash_byte_group_size=None,
         )
-        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
-
-        # Transformer layers
-        self.layers = nn.ModuleList(
-            [TransformerBlock(args) for _ in range(args.n_layers)]
-        )

         # Encoder ngram embedding tables
         self.encoder_ngram_embedding = None

@@ -848,9 +842,6 @@ class ByteLatentTransformer(nn.Module):

         # Output layer
         assert args.vocab_size > 0, "vocab_size must be greater than 0"
-        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
-        if args.weight_tying:
-            self.output.weight = self.tok_embeddings.weight

         # Patcher module
         if args.patch_in_forward:

@@ -954,11 +945,10 @@ class ByteLatentTransformer(nn.Module):
         local_encoder_embeds = local_encoder_embeds + ngram_embeds

         # Local encoder
-        h_cross = None
         (h_encoder, h_cross), cache_encoder = self.local_encoder(
             tokens=local_encoder_tokens,
             embeds=local_encoder_embeds,
-            patch_embeds=
+            patch_embeds=None,
             cross_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,

@@ -1033,47 +1023,17 @@ class ByteLatentTransformer(nn.Module):
         )
         return output

-    def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-        if not self.weight_tying:
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
     def init_weights(self):
-        self.
-        self.
-
-        factor = {
-            InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
-            InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
-            InitStdFactor.DIM_RATIO: self.dim / 4096,
-            InitStdFactor.DISABLED: 1.0,
-        }[self.init_std_factor]
-
-        layer.init_weights(self.init_base_std, factor)
-
-        self.local_decoder.init_weights(self.init_base_std)
-        self.global_transformer.init_weights(self.init_base_std)
-        self.local_encoder.init_weights(self.init_base_std)
+        self.local_encoder.init_weights()
+        self.global_transformer.init_weights()
+        self.local_decoder.init_weights()

+        emb_std = self.local_encoder.dim ** (-0.5)
         for emb in self.encoder_hash_tok_embedding:
             nn.init.trunc_normal_(
                 emb.weight,
                 mean=0.0,
-                std=
-                a=-3 *
-                b=3 *
+                std=emb_std,
+                a=-3 * emb_std,
+                b=3 * emb_std,
            )
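After the fix, ByteLatentTransformer.init_weights no longer re-initializes a top-level token embedding and output head (those modules were also dropped from the constructor); it delegates to the sub-models and only initializes the hash-embedding tables itself, at std = local_encoder.dim ** -0.5. A toy sketch of that delegation pattern; the classes below are illustrative placeholders, not the repository's modules:

import torch.nn as nn

class SubModel(nn.Module):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.dim = dim
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

    def init_weights(self):
        # Each sub-model owns its own initialization.
        std = self.dim ** (-0.5)
        nn.init.trunc_normal_(
            self.tok_embeddings.weight, mean=0.0, std=std, a=-3 * std, b=3 * std
        )

class Composite(nn.Module):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.local_encoder = SubModel(dim, vocab_size)
        self.local_decoder = SubModel(dim, vocab_size)

    def init_weights(self):
        # The parent only orchestrates, mirroring the fixed init_weights above.
        self.local_encoder.init_weights()
        self.local_decoder.init_weights()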
bytelatent/model/latent_transformer.py
CHANGED

@@ -78,10 +78,10 @@ class CrossAttention(nn.Module):
         # B S D
         bsz, seq_len, _ = x.shape
         _, slen_kv, _ = kv.shape
-
+        x_norm = self.cross_attn_norm_q(x)
         kv = self.cross_attn_norm_kv(kv)

-        xq = self.wq(
+        xq = self.wq(x_norm)
         xk = self.wk(kv)
         xv = self.wv(kv)


@@ -104,7 +104,7 @@ class CrossAttention(nn.Module):
         return x + output

     def init_weights(self, base_std: float, factor: float = 1.0):
-        std = base_std
+        std = base_std or (self.dim ** (-0.5)) / factor

         nn.init.trunc_normal_(
             self.wq.weight,

@@ -130,13 +130,12 @@ class CrossAttention(nn.Module):
             b=3 * std,
         )

-        output_std = std / (2**0.5)
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=
-            a=-3 *
-            b=3 *
+            std=std,
+            a=-3 * std,
+            b=3 * std,
         )
         self.cross_attn_norm_q.reset_parameters()
         self.cross_attn_norm_kv.reset_parameters()

@@ -147,6 +146,7 @@ class GlobalTransformer(BaseTransformer):
         super().__init__(args)
         self.dropout = args.dropout
         self.eos_id = args.eos_id
+        self.dim_token_emb = args.dim_token_emb

         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:

@@ -192,13 +192,14 @@ class GlobalTransformer(BaseTransformer):
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
         return h, cache

-    def init_weights(self
+    def init_weights(self):
         super().init_weights()
+        std = self.dim_token_emb ** (-0.5)
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
                 self.token_embedding_projection.weight,
                 mean=0.0,
-                std=
-                a=-3 *
-                b=3 *
+                std=std,
+                a=-3 * std,
+                b=3 * std,
             )
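The forward() fix normalizes the query input through cross_attn_norm_q before the query projection, while keeping the raw x for the residual connection. A self-contained toy version of that pre-norm cross-attention shape (not the repository's CrossAttention; nn.RMSNorm assumes a recent PyTorch):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCrossAttention(nn.Module):
    # Illustrative pre-norm cross-attention: both the query input and the
    # key/value input are normalized before projection, as in the fixed forward().
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.norm_q = nn.RMSNorm(dim)
        self.norm_kv = nn.RMSNorm(dim)
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        bsz, s_q, dim = x.shape
        head_dim = dim // self.n_heads
        x_norm = self.norm_q(x)          # normalize queries, keep x for the residual
        kv_norm = self.norm_kv(kv)
        q = self.wq(x_norm).view(bsz, s_q, self.n_heads, head_dim).transpose(1, 2)
        k = self.wk(kv_norm).view(bsz, -1, self.n_heads, head_dim).transpose(1, 2)
        v = self.wv(kv_norm).view(bsz, -1, self.n_heads, head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(bsz, s_q, dim)
        return x + self.wo(out)          # residual connection, as in the diff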
bytelatent/model/local_models.py
CHANGED

@@ -34,7 +34,7 @@ class LocalModelArgs(BaseTransformerArgs):
     # Local encoder specific dimensions
     dropout: float
     vocab_size: int
-    patch_size:
+    patch_size: float
     sliding_window: int | None
     use_rope: bool
     cross_attn_encoder: bool | None

@@ -61,6 +61,7 @@ class LocalModelBase(nn.Module):
         self.dropout = args.dropout
         self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
+        self.dim_patch_emb = args.dim_patch_emb

         self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window

@@ -130,6 +131,7 @@ class LocalModelBase(nn.Module):

     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
+        self.norm.reset_parameters()

         init_std = init_std or (self.dim ** (-0.5))
         nn.init.trunc_normal_(

@@ -156,33 +158,34 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]

-            layer.init_weights(
+            layer.init_weights(None, factor)

-        if self
+        if hasattr(self, "output"):
             nn.init.trunc_normal_(
-                self.
+                self.output.weight,
                 mean=0.0,
                 std=init_std,
                 a=-3 * init_std,
                 b=3 * init_std,
             )

-        if self.
+        if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
-                self.
+                self.token_embedding_projection.weight,
                 mean=0.0,
                 std=init_std,
                 a=-3 * init_std,
                 b=3 * init_std,
             )

-        if
+        if self.patch_embedding_projection is not None:
+            patch_emb_std = self.dim_patch_emb ** (-0.5)
             nn.init.trunc_normal_(
-                self.
+                self.patch_embedding_projection.weight,
                 mean=0.0,
-                std=
-                a=-3 *
-                b=3 *
+                std=patch_emb_std,
+                a=-3 * patch_emb_std,
+                b=3 * patch_emb_std,
             )

         if self.cross_attn_layers is not None:

@@ -194,7 +197,7 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]

-            layer.init_weights(
+            layer.init_weights(None, factor)


 class LocalEncoder(LocalModelBase):
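LocalEncoder and LocalDecoder share LocalModelBase.init_weights, but only the decoder has an output head, so the fixed code guards that branch with hasattr(self, "output"). A small sketch of the pattern, with placeholder classes that are not the repository's:

import torch.nn as nn

class Base(nn.Module):
    def init_weights(self, init_std=None):
        init_std = init_std or (self.dim ** (-0.5))
        # Only some subclasses define an output head; hasattr keeps the shared
        # init path valid for both, as in the fixed LocalModelBase.init_weights.
        if hasattr(self, "output"):
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

class DecoderLike(Base):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.dim = dim
        self.output = nn.Linear(dim, vocab_size, bias=False)

class EncoderLike(Base):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

DecoderLike(256, 260).init_weights()
EncoderLike(256).init_weights()  # no output head, branch is skipped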
bytelatent/transformer.py
CHANGED

@@ -137,14 +137,25 @@ def get_no_recompute_ops():
 def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
     group_plan: Tuple[int, bool] = []

-
-
-
-
-
-
-
-
+    if isinstance(model_args, LMTransformerArgs):
+        group_plan.append(("tok_embeddings", False))
+
+        for i in range(model_args.n_layers):
+            group_plan.append((f"layers.{i}", False))
+
+        group_plan.append(("output", True))
+    else:
+        for i in range(model_args.n_layers_local_encoder):
+            group_plan.append((f"local_encoder.layers.{i}", True))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_local_decoder):
+            group_plan.append((f"local_decoder.layers.{i}", True))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_global):
+            group_plan.append((f"global_transformer.layers.{i}", True))
+
+        for i in range(len(model_args.encoder_hash_byte_group_size)):
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))

     return group_plan
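The grouping plan is a list of (submodule path, bool) pairs; the BLT branch now registers every encoder, decoder, and global block, plus each hash-embedding table, as its own group. A sketch of how such a plan could be consumed with PyTorch's per-module FSDP2 API; treating the bool as a reshard-after-forward flag and the fully_shard usage itself are assumptions for illustration, not taken from this diff:

from typing import List, Tuple

import torch.nn as nn
# FSDP2 per-module API (recent PyTorch); usage here is an assumption for illustration.
from torch.distributed._composable.fsdp import fully_shard

def apply_grouping_plan(model: nn.Module, group_plan: List[Tuple[str, bool]]) -> None:
    for module_path, reshard in group_plan:
        # Resolve dotted paths such as "local_encoder.layers.0".
        submodule = model.get_submodule(module_path)
        fully_shard(submodule, reshard_after_forward=reshard)
    # Shard whatever is left (embeddings, norms, output) as the root group.
    fully_shard(model)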