Update modeling_gpt_refact.py
Browse files- modeling_gpt_refact.py +92 -76
modeling_gpt_refact.py
CHANGED
|
@@ -21,29 +21,23 @@ logger = logging.get_logger(__name__)
|
|
| 21 |
|
| 22 |
@torch.jit.script
|
| 23 |
def upcast_masked_softmax(
|
| 24 |
-
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor,
|
| 25 |
):
|
| 26 |
input_dtype = x.dtype
|
| 27 |
-
x = x.to(softmax_dtype)
|
| 28 |
x = torch.where(mask, x, mask_value)
|
| 29 |
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
| 30 |
return x
|
| 31 |
|
| 32 |
|
| 33 |
@torch.jit.script
|
| 34 |
-
def upcast_softmax(x: torch.Tensor,
|
| 35 |
input_dtype = x.dtype
|
| 36 |
-
x = x.to(softmax_dtype)
|
| 37 |
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
| 38 |
return x
|
| 39 |
|
| 40 |
|
| 41 |
-
@torch.jit.script
|
| 42 |
-
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
|
| 43 |
-
x = torch.where(mask, x, mask_value)
|
| 44 |
-
x = torch.nn.functional.softmax(x, dim=-1)
|
| 45 |
-
return x
|
| 46 |
-
|
| 47 |
@torch.jit.script
|
| 48 |
def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
|
| 49 |
"""
|
|
@@ -76,7 +70,6 @@ def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
|
|
| 76 |
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
|
| 77 |
# Concatenate the slopes with the remaining slopes.
|
| 78 |
m = torch.cat([m, m_hat])
|
| 79 |
-
|
| 80 |
return m
|
| 81 |
|
| 82 |
@torch.jit.script
|
|
@@ -85,8 +78,7 @@ def get_alibi_biases(
|
|
| 85 |
T: int,
|
| 86 |
attn_heads: int,
|
| 87 |
dev: torch.device,
|
| 88 |
-
dtype: torch.dtype
|
| 89 |
-
causal: bool = True) -> torch.Tensor:
|
| 90 |
"""
|
| 91 |
## Calculate the attention biases matrix
|
| 92 |
* `n_heads` is the number of heads in the attention layer
|
|
@@ -95,28 +87,25 @@ def get_alibi_biases(
|
|
| 95 |
"""
|
| 96 |
|
| 97 |
# Get slopes $m$ for each head
|
| 98 |
-
|
| 99 |
-
mask = (torch.triu(torch.ones((T, T), device=dev)) == 1).transpose(0, 1)
|
| 100 |
-
else:
|
| 101 |
-
mask = torch.ones((T, T), device=dev, dtype=torch.bool)
|
| 102 |
|
| 103 |
-
m = _get_slopes(attn_heads, dev)
|
| 104 |
|
| 105 |
# Calculate distances $[0, 1, \dots, N]$
|
| 106 |
# Here we calculate the distances using the mask.
|
| 107 |
#
|
| 108 |
# Since it's causal mask we can just use $[0, 1, \dots, N]$ too.
|
| 109 |
# `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
|
| 110 |
-
distance = mask.cumsum(dim=-1)
|
| 111 |
|
| 112 |
# Multiply them pair-wise to get the AliBi bias matrix
|
| 113 |
biases = distance[:, :, None] * m[None, None, :]
|
| 114 |
biases = biases.permute(2, 0, 1)[None, :, :T, :T]
|
| 115 |
-
|
| 116 |
-
return biases.to(dtype).contiguous()
|
| 117 |
|
| 118 |
|
| 119 |
class Attention(nn.Module):
|
|
|
|
| 120 |
def __init__(self, config, layer_idx=None):
|
| 121 |
super().__init__()
|
| 122 |
self.mask_value = None
|
|
@@ -126,7 +115,7 @@ class Attention(nn.Module):
|
|
| 126 |
self.head_dim = self.embed_dim // self.num_heads
|
| 127 |
self.kv_attn_heads = 1
|
| 128 |
|
| 129 |
-
self.
|
| 130 |
|
| 131 |
if self.head_dim * self.num_heads != self.embed_dim:
|
| 132 |
raise ValueError(
|
|
@@ -139,41 +128,63 @@ class Attention(nn.Module):
|
|
| 139 |
self.scale_attention_softmax_in_fp32 = (
|
| 140 |
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
|
| 141 |
)
|
|
|
|
| 142 |
|
| 143 |
self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
| 144 |
-
self.
|
| 145 |
-
self.v = nn.Linear(self.embed_dim, self.head_dim, bias=False)
|
| 146 |
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
def _attn(self, query, key, value, attention_mask=None, alibi=None):
|
| 149 |
dtype = query.dtype
|
| 150 |
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
|
|
|
|
| 151 |
upcast = dtype != softmax_dtype
|
| 152 |
-
unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
|
| 153 |
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
if upcast:
|
|
|
|
|
|
|
| 157 |
if attention_mask is None:
|
| 158 |
-
attn_weights = upcast_softmax(attn_weights,
|
| 159 |
else:
|
| 160 |
-
|
| 161 |
-
attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
|
| 162 |
else:
|
| 163 |
if attention_mask is not None:
|
| 164 |
-
|
| 165 |
-
|
| 166 |
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
| 167 |
|
| 168 |
-
attn_output = torch.
|
| 169 |
|
| 170 |
return attn_output, attn_weights
|
| 171 |
|
| 172 |
-
def _split_heads(self, tensor):
|
| 173 |
-
new_shape = tensor.shape[:-1] + (self.num_heads, self.head_dim)
|
| 174 |
-
tensor = tensor.view(new_shape)
|
| 175 |
-
return tensor.permute(0, 2, 1, 3)
|
| 176 |
-
|
| 177 |
def forward(
|
| 178 |
self,
|
| 179 |
hidden_states: torch.Tensor,
|
|
@@ -186,13 +197,9 @@ class Attention(nn.Module):
|
|
| 186 |
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
| 187 |
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
|
| 188 |
]:
|
| 189 |
-
b, t, _ = hidden_states.shape
|
| 190 |
query = self.q(hidden_states)
|
| 191 |
-
|
| 192 |
-
value = self.
|
| 193 |
-
query = self._split_heads(query)
|
| 194 |
-
key = key.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 195 |
-
value = value.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 196 |
|
| 197 |
if layer_past is not None:
|
| 198 |
past_key, past_value = layer_past
|
|
@@ -205,32 +212,31 @@ class Attention(nn.Module):
|
|
| 205 |
present = None
|
| 206 |
|
| 207 |
attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
|
| 208 |
-
|
| 209 |
-
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
|
| 210 |
attn_output = self.c_proj(attn_output)
|
| 211 |
|
| 212 |
outputs = (attn_output, present)
|
| 213 |
if output_attentions:
|
|
|
|
| 214 |
outputs += (attn_weights,)
|
| 215 |
|
| 216 |
return outputs # a, present, (attentions)
|
| 217 |
|
| 218 |
|
| 219 |
class MLP(nn.Module):
|
|
|
|
| 220 |
def __init__(self, intermediate_size, config, multiple_of: int = 256):
|
| 221 |
super().__init__()
|
| 222 |
embed_dim = config.hidden_size
|
| 223 |
hidden_dim = intermediate_size
|
| 224 |
hidden_dim = int(2 * hidden_dim / 3)
|
| 225 |
-
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
| 226 |
-
self.
|
| 227 |
-
self.
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
x1 =
|
| 232 |
-
|
| 233 |
-
x = self.c_proj(x1 * x2)
|
| 234 |
return x
|
| 235 |
|
| 236 |
|
|
@@ -255,7 +261,6 @@ class GPTRefactBlock(nn.Module):
|
|
| 255 |
self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
|
| 256 |
self.attn = Attention(config, layer_idx=layer_idx)
|
| 257 |
self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
|
| 258 |
-
|
| 259 |
self.mlp = MLP(self.inner_dim, config)
|
| 260 |
|
| 261 |
def forward(
|
|
@@ -297,6 +302,7 @@ class GPTRefactBlock(nn.Module):
|
|
| 297 |
|
| 298 |
|
| 299 |
class GPTRefactPreTrainedModel(PreTrainedModel):
|
|
|
|
| 300 |
config_class = GPTRefactConfig
|
| 301 |
base_model_prefix = "transformer"
|
| 302 |
supports_gradient_checkpointing = True
|
|
@@ -331,12 +337,9 @@ class GPTRefactPreTrainedModel(PreTrainedModel):
|
|
| 331 |
elif isinstance(module, LayerNormNoBias):
|
| 332 |
module.weight.data.fill_(1.0)
|
| 333 |
|
| 334 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
| 335 |
-
if isinstance(module, GPTRefactModel):
|
| 336 |
-
module.gradient_checkpointing = value
|
| 337 |
-
|
| 338 |
|
| 339 |
class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
|
|
| 340 |
def __init__(self, config):
|
| 341 |
super().__init__(config)
|
| 342 |
self.embed_dim = config.hidden_size
|
|
@@ -347,6 +350,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
| 347 |
self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
| 348 |
|
| 349 |
self.max_positions = config.max_position_embeddings
|
|
|
|
| 350 |
self.register_buffer(
|
| 351 |
"bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
|
| 352 |
persistent=False
|
|
@@ -357,15 +361,8 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
| 357 |
# Initialize weights and apply final processing
|
| 358 |
self.post_init()
|
| 359 |
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
# prompt
|
| 363 |
-
if past_key_values_length == 0:
|
| 364 |
-
mask = torch.ones((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
|
| 365 |
-
mask = torch.triu(mask, 1)
|
| 366 |
-
else:
|
| 367 |
-
mask = torch.zeros((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
|
| 368 |
-
return mask
|
| 369 |
|
| 370 |
def forward(
|
| 371 |
self,
|
|
@@ -408,19 +405,25 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
| 408 |
else:
|
| 409 |
past_length = past_key_values[0][0].size(-2)
|
| 410 |
|
| 411 |
-
# Self-attention mask.
|
| 412 |
query_length = input_shape[-1]
|
| 413 |
-
|
| 414 |
seq_length_with_past = past_length + query_length
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
|
| 421 |
|
|
|
|
| 422 |
alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
|
| 423 |
-
self.num_heads, device,
|
| 424 |
|
| 425 |
output_shape = input_shape + (hidden_states.size(-1),)
|
| 426 |
|
|
@@ -489,6 +492,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
| 489 |
|
| 490 |
|
| 491 |
class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
|
|
|
| 492 |
_tied_weights_keys = ["lm_head.weight", "ln_f.weight"]
|
| 493 |
|
| 494 |
def __init__(self, config):
|
|
@@ -499,6 +503,18 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
|
| 499 |
|
| 500 |
# Initialize weights and apply final processing
|
| 501 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
|
| 503 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
| 504 |
if inputs_embeds is not None and past_key_values is None:
|
|
@@ -583,4 +599,4 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
|
| 583 |
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
| 584 |
beam_idx at every generation step.
|
| 585 |
"""
|
| 586 |
-
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
|
|
|
|
| 21 |
|
| 22 |
@torch.jit.script
|
| 23 |
def upcast_masked_softmax(
|
| 24 |
+
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, softmax_dtype: torch.dtype
|
| 25 |
):
|
| 26 |
input_dtype = x.dtype
|
| 27 |
+
x = x.to(softmax_dtype)
|
| 28 |
x = torch.where(mask, x, mask_value)
|
| 29 |
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
| 30 |
return x
|
| 31 |
|
| 32 |
|
| 33 |
@torch.jit.script
|
| 34 |
+
def upcast_softmax(x: torch.Tensor, softmax_dtype: torch.dtype):
|
| 35 |
input_dtype = x.dtype
|
| 36 |
+
x = x.to(softmax_dtype)
|
| 37 |
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
| 38 |
return x
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
@torch.jit.script
|
| 42 |
def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
|
| 43 |
"""
|
|
|
|
| 70 |
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
|
| 71 |
# Concatenate the slopes with the remaining slopes.
|
| 72 |
m = torch.cat([m, m_hat])
|
|
|
|
| 73 |
return m
|
| 74 |
|
| 75 |
@torch.jit.script
|
|
|
|
| 78 |
T: int,
|
| 79 |
attn_heads: int,
|
| 80 |
dev: torch.device,
|
| 81 |
+
dtype: torch.dtype) -> torch.Tensor:
|
|
|
|
| 82 |
"""
|
| 83 |
## Calculate the attention biases matrix
|
| 84 |
* `n_heads` is the number of heads in the attention layer
|
|
|
|
| 87 |
"""
|
| 88 |
|
| 89 |
# Get slopes $m$ for each head
|
| 90 |
+
mask = torch.ones((T, T), device=dev, dtype=torch.bool)
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
m = _get_slopes(attn_heads, dev).to(dtype)
|
| 93 |
|
| 94 |
# Calculate distances $[0, 1, \dots, N]$
|
| 95 |
# Here we calculate the distances using the mask.
|
| 96 |
#
|
| 97 |
# Since it's causal mask we can just use $[0, 1, \dots, N]$ too.
|
| 98 |
# `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
|
| 99 |
+
distance = mask.cumsum(dim=-1).to(dtype)
|
| 100 |
|
| 101 |
# Multiply them pair-wise to get the AliBi bias matrix
|
| 102 |
biases = distance[:, :, None] * m[None, None, :]
|
| 103 |
biases = biases.permute(2, 0, 1)[None, :, :T, :T]
|
| 104 |
+
return biases.contiguous()
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
class Attention(nn.Module):
|
| 108 |
+
|
| 109 |
def __init__(self, config, layer_idx=None):
|
| 110 |
super().__init__()
|
| 111 |
self.mask_value = None
|
|
|
|
| 115 |
self.head_dim = self.embed_dim // self.num_heads
|
| 116 |
self.kv_attn_heads = 1
|
| 117 |
|
| 118 |
+
self.scale_factor = self.head_dim ** -0.5
|
| 119 |
|
| 120 |
if self.head_dim * self.num_heads != self.embed_dim:
|
| 121 |
raise ValueError(
|
|
|
|
| 128 |
self.scale_attention_softmax_in_fp32 = (
|
| 129 |
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
|
| 130 |
)
|
| 131 |
+
self.attention_bias_in_fp32 = config.attention_bias_in_fp32
|
| 132 |
|
| 133 |
self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
| 134 |
+
self.kv = nn.Linear(self.embed_dim, self.head_dim * 2, bias=False)
|
|
|
|
| 135 |
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
| 136 |
|
| 137 |
+
def _get_mask_value(self, device, dtype):
|
| 138 |
+
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
|
| 139 |
+
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
|
| 140 |
+
self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
|
| 141 |
+
return self.mask_value
|
| 142 |
+
|
| 143 |
def _attn(self, query, key, value, attention_mask=None, alibi=None):
|
| 144 |
dtype = query.dtype
|
| 145 |
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
|
| 146 |
+
mask_value = self._get_mask_value(query.device, softmax_dtype)
|
| 147 |
upcast = dtype != softmax_dtype
|
|
|
|
| 148 |
|
| 149 |
+
query_shape = query.shape
|
| 150 |
+
batch_size = query_shape[0]
|
| 151 |
+
key_length = key.size(-1)
|
| 152 |
+
|
| 153 |
+
# (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
|
| 154 |
+
# -> (batch_size, query_length, num_heads, key_length)
|
| 155 |
+
query_length = query_shape[1]
|
| 156 |
+
attn_shape = (batch_size, query_length, self.num_heads, key_length)
|
| 157 |
+
attn_view = (batch_size, query_length * self.num_heads, key_length)
|
| 158 |
+
# No copy needed for MQA 2, or when layer_past is provided.
|
| 159 |
+
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
|
| 160 |
+
|
| 161 |
+
alibi = alibi.transpose(2, 1).reshape(alibi.shape[0], -1, alibi.shape[-1])
|
| 162 |
+
initial_dtype = query.dtype
|
| 163 |
+
new_dtype = torch.float32 if self.attention_bias_in_fp32 else initial_dtype
|
| 164 |
+
attn_weights = alibi.baddbmm(
|
| 165 |
+
batch1=query.to(new_dtype),
|
| 166 |
+
batch2=key.to(new_dtype),
|
| 167 |
+
beta=1,
|
| 168 |
+
alpha=self.scale_factor
|
| 169 |
+
).view(attn_shape).to(initial_dtype)
|
| 170 |
|
| 171 |
if upcast:
|
| 172 |
+
# Use a fused kernel to prevent a large overhead from casting and scaling.
|
| 173 |
+
# Sub-optimal when the key length is not a multiple of 8.
|
| 174 |
if attention_mask is None:
|
| 175 |
+
attn_weights = upcast_softmax(attn_weights, softmax_dtype)
|
| 176 |
else:
|
| 177 |
+
attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, softmax_dtype)
|
|
|
|
| 178 |
else:
|
| 179 |
if attention_mask is not None:
|
| 180 |
+
# The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
|
| 181 |
+
attn_weights = torch.where(attention_mask, attn_weights, mask_value)
|
| 182 |
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
| 183 |
|
| 184 |
+
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
|
| 185 |
|
| 186 |
return attn_output, attn_weights
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
def forward(
|
| 189 |
self,
|
| 190 |
hidden_states: torch.Tensor,
|
|
|
|
| 197 |
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
| 198 |
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
|
| 199 |
]:
|
|
|
|
| 200 |
query = self.q(hidden_states)
|
| 201 |
+
kv = self.kv(hidden_states)
|
| 202 |
+
key, value = kv.split(self.head_dim, dim=-1)
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
if layer_past is not None:
|
| 205 |
past_key, past_value = layer_past
|
|
|
|
| 212 |
present = None
|
| 213 |
|
| 214 |
attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
|
|
|
|
|
|
|
| 215 |
attn_output = self.c_proj(attn_output)
|
| 216 |
|
| 217 |
outputs = (attn_output, present)
|
| 218 |
if output_attentions:
|
| 219 |
+
attn_weights = attn_weights.transpose(1, 2)
|
| 220 |
outputs += (attn_weights,)
|
| 221 |
|
| 222 |
return outputs # a, present, (attentions)
|
| 223 |
|
| 224 |
|
| 225 |
class MLP(nn.Module):
|
| 226 |
+
|
| 227 |
def __init__(self, intermediate_size, config, multiple_of: int = 256):
|
| 228 |
super().__init__()
|
| 229 |
embed_dim = config.hidden_size
|
| 230 |
hidden_dim = intermediate_size
|
| 231 |
hidden_dim = int(2 * hidden_dim / 3)
|
| 232 |
+
self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
| 233 |
+
self.gate_up_proj = nn.Linear(embed_dim, self.hidden_dim * 2, bias=False)
|
| 234 |
+
self.c_proj = nn.Linear(self.hidden_dim, embed_dim, bias=False)
|
| 235 |
+
|
| 236 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 237 |
+
up_proj = self.gate_up_proj(x)
|
| 238 |
+
x1, x2 = torch.split(up_proj, self.hidden_dim, dim=-1)
|
| 239 |
+
x = self.c_proj(F.silu(x1) * x2)
|
|
|
|
| 240 |
return x
|
| 241 |
|
| 242 |
|
|
|
|
| 261 |
self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
|
| 262 |
self.attn = Attention(config, layer_idx=layer_idx)
|
| 263 |
self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
|
|
|
|
| 264 |
self.mlp = MLP(self.inner_dim, config)
|
| 265 |
|
| 266 |
def forward(
|
|
|
|
| 302 |
|
| 303 |
|
| 304 |
class GPTRefactPreTrainedModel(PreTrainedModel):
|
| 305 |
+
|
| 306 |
config_class = GPTRefactConfig
|
| 307 |
base_model_prefix = "transformer"
|
| 308 |
supports_gradient_checkpointing = True
|
|
|
|
| 337 |
elif isinstance(module, LayerNormNoBias):
|
| 338 |
module.weight.data.fill_(1.0)
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
class GPTRefactModel(GPTRefactPreTrainedModel):
|
| 342 |
+
|
| 343 |
def __init__(self, config):
|
| 344 |
super().__init__(config)
|
| 345 |
self.embed_dim = config.hidden_size
|
|
|
|
| 350 |
self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
| 351 |
|
| 352 |
self.max_positions = config.max_position_embeddings
|
| 353 |
+
self.attention_bias_in_fp32 = config.attention_bias_in_fp32
|
| 354 |
self.register_buffer(
|
| 355 |
"bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
|
| 356 |
persistent=False
|
|
|
|
| 361 |
# Initialize weights and apply final processing
|
| 362 |
self.post_init()
|
| 363 |
|
| 364 |
+
def get_input_embeddings(self):
|
| 365 |
+
return self.wte
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
def forward(
|
| 368 |
self,
|
|
|
|
| 405 |
else:
|
| 406 |
past_length = past_key_values[0][0].size(-2)
|
| 407 |
|
|
|
|
| 408 |
query_length = input_shape[-1]
|
|
|
|
| 409 |
seq_length_with_past = past_length + query_length
|
| 410 |
+
|
| 411 |
+
# Self-attention mask.
|
| 412 |
+
key_length = past_length + query_length
|
| 413 |
+
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
|
| 414 |
+
if attention_mask is not None:
|
| 415 |
+
self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
|
| 416 |
+
dtype=torch.bool, device=self_attention_mask.device
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# MQA models: (batch_size, query_length, n_heads, key_length)
|
| 420 |
+
attention_mask = self_attention_mask.unsqueeze(2)
|
| 421 |
|
| 422 |
hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
|
| 423 |
|
| 424 |
+
alibi_dtype = torch.float32 if self.attention_bias_in_fp32 else self.wte.weight.dtype
|
| 425 |
alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
|
| 426 |
+
self.num_heads, device, alibi_dtype)[:, :, -query_length:, :]
|
| 427 |
|
| 428 |
output_shape = input_shape + (hidden_states.size(-1),)
|
| 429 |
|
|
|
|
| 492 |
|
| 493 |
|
| 494 |
class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
| 495 |
+
|
| 496 |
_tied_weights_keys = ["lm_head.weight", "ln_f.weight"]
|
| 497 |
|
| 498 |
def __init__(self, config):
|
|
|
|
| 503 |
|
| 504 |
# Initialize weights and apply final processing
|
| 505 |
self.post_init()
|
| 506 |
+
|
| 507 |
+
# gradient checkpointing support for lower versions of transformers
|
| 508 |
+
import transformers
|
| 509 |
+
from packaging import version
|
| 510 |
+
|
| 511 |
+
def _set_gradient_checkpointing(module, value=False):
|
| 512 |
+
if isinstance(module, GPTRefactModel):
|
| 513 |
+
module.gradient_checkpointing = value
|
| 514 |
+
|
| 515 |
+
v = version.parse(transformers.__version__)
|
| 516 |
+
if v.major <= 4 and v.minor < 35:
|
| 517 |
+
self._set_gradient_checkpointing = _set_gradient_checkpointing
|
| 518 |
|
| 519 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
| 520 |
if inputs_embeds is not None and past_key_values is None:
|
|
|
|
| 599 |
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
| 600 |
beam_idx at every generation step.
|
| 601 |
"""
|
| 602 |
+
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
|