nuojohnchen committed
Commit a8fea83 · verified · 1 Parent(s): 53edf96

Delete modeling_upcycling_qwen2_moe.py

Files changed (1)
  1. modeling_upcycling_qwen2_moe.py +0 -1672
modeling_upcycling_qwen2_moe.py DELETED
@@ -1,1672 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """PyTorch Qwen2MoE model."""
21
-
22
- import inspect
23
- import math
24
- from typing import List, Optional, Tuple, Union
25
-
26
- import torch
27
- import torch.nn.functional as F
28
- import torch.utils.checkpoint
29
- from torch import nn
30
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
-
32
- from transformers.activations import ACT2FN
33
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
- from transformers.modeling_attn_mask_utils import (
35
- AttentionMaskConverter,
36
- )
37
- from transformers.modeling_outputs import (
38
- MoeCausalLMOutputWithPast,
39
- MoeModelOutputWithPast,
40
- )
41
- from transformers.modeling_utils import PreTrainedModel
42
- from transformers.utils import (
43
- is_flash_attn_2_available,
44
- is_flash_attn_greater_or_equal_2_10,
45
- logging,
46
- replace_return_docstrings,
47
- )
48
- from .configuration_upcycling_qwen2_moe import UpcyclingQwen2MoeConfig
49
- from transformers import AutoModelForCausalLM,AutoConfig,AutoModel
50
-
51
-
52
- if is_flash_attn_2_available():
53
- from flash_attn import flash_attn_func, flash_attn_varlen_func
54
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55
-
56
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
57
-
58
- logger = logging.get_logger(__name__)
59
-
60
- _CHECKPOINT_FOR_DOC = "UpcyclingQwen2MoE"
61
- _CONFIG_FOR_DOC = "UpcyclingQwen2MoeConfig"
62
-
63
-
64
- # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
65
- def load_balancing_loss_func(
66
- gate_logits: torch.Tensor, num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
67
- ) -> float:
68
- r"""
69
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
70
-
71
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
72
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
73
- experts is too unbalanced.
74
-
75
- Args:
76
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
77
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
78
- shape [batch_size X sequence_length, num_experts].
79
- attention_mask (`torch.Tensor`, None):
80
- The attention_mask used in forward function
81
- shape [batch_size X sequence_length] if not None.
82
- num_experts (`int`, *optional*):
83
- Number of experts
84
-
85
- Returns:
86
- The auxiliary loss.
87
- """
88
- if gate_logits is None or not isinstance(gate_logits, tuple):
89
- return 0
90
-
91
- if isinstance(gate_logits, tuple):
92
- compute_device = gate_logits[0].device
93
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
94
-
95
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
96
-
97
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
98
-
99
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
100
-
101
- if attention_mask is None:
102
- # Compute the percentage of tokens routed to each experts
103
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
104
-
105
- # Compute the average probability of routing to these experts
106
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
107
- else:
108
- batch_size, sequence_length = attention_mask.shape
109
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
110
-
111
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
112
- expert_attention_mask = (
113
- attention_mask[None, :, :, None, None]
114
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
115
- .reshape(-1, top_k, num_experts)
116
- .to(compute_device)
117
- )
118
-
119
- # Compute the percentage of tokens routed to each experts
120
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
121
- expert_attention_mask, dim=0
122
- )
123
-
124
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
125
- router_per_expert_attention_mask = (
126
- attention_mask[None, :, :, None]
127
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
128
- .reshape(-1, num_experts)
129
- .to(compute_device)
130
- )
131
-
132
- # Compute the average probability of routing to these experts
133
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
134
- router_per_expert_attention_mask, dim=0
135
- )
136
-
137
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
138
- return overall_loss * num_experts
139
-
140
-
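The docstring above points to equations (4)-(6) of the Switch Transformer paper: the loss multiplies, per expert, the fraction of tokens routed to it by its mean router probability, sums over experts, and scales by num_experts, so a perfectly uniform top-k router yields exactly top_k. A minimal sketch of calling the helper on toy tensors (shapes are illustrative, not the model's real sizes):

    import torch

    # Toy setup: 2 layers, 4 tokens each, 4 experts, top-2 routing.
    torch.manual_seed(0)
    gate_logits = tuple(torch.randn(4, 4) for _ in range(2))  # one logits tensor per layer
    aux_loss = load_balancing_loss_func(gate_logits, num_experts=4, top_k=2)
    print(aux_loss)  # for a perfectly uniform router this would equal top_k (= 2 here)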
141
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
142
- def _get_unpad_data(attention_mask):
143
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
144
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
145
- max_seqlen_in_batch = seqlens_in_batch.max().item()
146
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
147
- return (
148
- indices,
149
- cu_seqlens,
150
- max_seqlen_in_batch,
151
- )
152
-
153
-
154
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2Moe
155
- class Qwen2MoeRMSNorm(nn.Module):
156
- def __init__(self, hidden_size, eps=1e-6):
157
- """
158
- Qwen2MoeRMSNorm is equivalent to T5LayerNorm
159
- """
160
- super().__init__()
161
- self.weight = nn.Parameter(torch.ones(hidden_size))
162
- self.variance_epsilon = eps
163
-
164
- def forward(self, hidden_states):
165
- input_dtype = hidden_states.dtype
166
- hidden_states = hidden_states.to(torch.float32)
167
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
168
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
169
- return self.weight * hidden_states.to(input_dtype)
170
-
171
-
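As the docstring notes, this is T5-style RMS normalization: each hidden vector is scaled by the reciprocal root mean square of its entries, with no mean subtraction. A quick numerical check, assuming the class above is in scope (toy sizes):

    import torch

    norm = Qwen2MoeRMSNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 3, 8)
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    # weight is initialized to ones, so the module output matches the manual formula
    assert torch.allclose(norm(x), manual, atol=1e-6)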
172
- # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
173
- class Qwen2MoeRotaryEmbedding(nn.Module):
174
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
175
- super().__init__()
176
-
177
- self.dim = dim
178
- self.max_position_embeddings = max_position_embeddings
179
- self.base = base
180
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
181
- self.register_buffer("inv_freq", inv_freq, persistent=False)
182
-
183
- # Build here to make `torch.jit.trace` work.
184
- self._set_cos_sin_cache(
185
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
186
- )
187
-
188
- def _set_cos_sin_cache(self, seq_len, device, dtype):
189
- self.max_seq_len_cached = seq_len
190
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
191
-
192
- freqs = torch.outer(t, self.inv_freq)
193
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
194
- emb = torch.cat((freqs, freqs), dim=-1)
195
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
196
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
197
-
198
- def forward(self, x, seq_len=None):
199
- # x: [bs, num_attention_heads, seq_len, head_size]
200
- if seq_len > self.max_seq_len_cached:
201
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
202
-
203
- return (
204
- self.cos_cached[:seq_len].to(dtype=x.dtype),
205
- self.sin_cached[:seq_len].to(dtype=x.dtype),
206
- )
207
-
208
-
209
- # Copied from transformers.models.llama.modeling_llama.rotate_half
210
- def rotate_half(x):
211
- """Rotates half the hidden dims of the input."""
212
- x1 = x[..., : x.shape[-1] // 2]
213
- x2 = x[..., x.shape[-1] // 2 :]
214
- return torch.cat((-x2, x1), dim=-1)
215
-
216
-
217
- # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
218
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
219
- """Applies Rotary Position Embedding to the query and key tensors.
220
-
221
- Args:
222
- q (`torch.Tensor`): The query tensor.
223
- k (`torch.Tensor`): The key tensor.
224
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
225
- sin (`torch.Tensor`): The sine part of the rotary embedding.
226
- position_ids (`torch.Tensor`):
227
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
228
- used to pass offsetted position ids when working with a KV-cache.
229
- unsqueeze_dim (`int`, *optional*, defaults to 1):
230
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
231
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
232
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
233
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
234
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
235
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
236
- Returns:
237
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
238
- """
239
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
240
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
241
- q_embed = (q * cos) + (rotate_half(q) * sin)
242
- k_embed = (k * cos) + (rotate_half(k) * sin)
243
- return q_embed, k_embed
244
-
245
-
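A shape-level sketch of how the rotary embedding class and the two helpers above are combined in the attention forward pass (dimensions are illustrative; the definitions above are assumed to be in scope):

    import torch

    bsz, heads, seq_len, head_dim = 1, 2, 5, 16
    rope = Qwen2MoeRotaryEmbedding(dim=head_dim, max_position_embeddings=32)
    q = torch.randn(bsz, heads, seq_len, head_dim)
    k = torch.randn(bsz, heads, seq_len, head_dim)
    position_ids = torch.arange(seq_len).unsqueeze(0)       # [bsz, seq_len]
    cos, sin = rope(q, seq_len=seq_len)                      # each [seq_len, head_dim]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)  # shapes unchanged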
246
- # Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe
247
- class Qwen2MoeMLP(nn.Module):
248
- def __init__(self, config, intermediate_size=None):
249
- super().__init__()
250
- self.config = config
251
- self.hidden_size = config.hidden_size
252
- self.intermediate_size = intermediate_size
253
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
254
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
255
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
256
- self.act_fn = ACT2FN[config.hidden_act]
257
-
258
- def forward(self, x,language_ids:Optional[torch.LongTensor]=None):
259
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
260
-
261
-
262
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
263
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
264
- """
265
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
266
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
267
- """
268
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
269
- if n_rep == 1:
270
- return hidden_states
271
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
272
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
273
-
274
-
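For grouped-query attention, repeat_kv expands the key/value heads to line up with the query heads; a toy shape check (sizes are hypothetical):

    import torch

    kv = torch.randn(2, 4, 10, 64)       # (batch, num_key_value_heads, seq_len, head_dim)
    out = repeat_kv(kv, n_rep=4)          # n_rep = num_attention_heads // num_key_value_heads
    assert out.shape == (2, 16, 10, 64)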
275
- # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe
276
- class Qwen2MoeAttention(nn.Module):
277
- """
278
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
279
- and "Generating Long Sequences with Sparse Transformers".
280
- """
281
-
282
- def __init__(self, config: UpcyclingQwen2MoeConfig, layer_idx: Optional[int] = None):
283
- super().__init__()
284
- self.config = config
285
- self.layer_idx = layer_idx
286
- if layer_idx is None:
287
- logger.warning_once(
288
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
289
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
290
- "when creating this class."
291
- )
292
-
293
- self.hidden_size = config.hidden_size
294
- self.num_heads = config.num_attention_heads
295
- self.head_dim = self.hidden_size // self.num_heads
296
- self.num_key_value_heads = config.num_key_value_heads
297
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
298
- self.max_position_embeddings = config.max_position_embeddings
299
- self.rope_theta = config.rope_theta
300
- self.is_causal = True
301
- self.attention_dropout = config.attention_dropout
302
-
303
- if (self.head_dim * self.num_heads) != self.hidden_size:
304
- raise ValueError(
305
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
306
- f" and `num_heads`: {self.num_heads})."
307
- )
308
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
309
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
310
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
311
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
312
-
313
- self.rotary_emb = Qwen2MoeRotaryEmbedding(
314
- self.head_dim,
315
- max_position_embeddings=self.max_position_embeddings,
316
- base=self.rope_theta,
317
- )
318
-
319
- def forward(
320
- self,
321
- hidden_states: torch.Tensor,
322
- attention_mask: Optional[torch.Tensor] = None,
323
- position_ids: Optional[torch.LongTensor] = None,
324
- past_key_value: Optional[Cache] = None,
325
- output_attentions: bool = False,
326
- use_cache: bool = False,
327
- cache_position: Optional[torch.LongTensor] = None,
328
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
329
- bsz, q_len, _ = hidden_states.size()
330
-
331
- query_states = self.q_proj(hidden_states)
332
- key_states = self.k_proj(hidden_states)
333
- value_states = self.v_proj(hidden_states)
334
-
335
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
336
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
337
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
338
-
339
- kv_seq_len = key_states.shape[-2]
340
- if past_key_value is not None:
341
- if self.layer_idx is None:
342
- raise ValueError(
343
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
344
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
345
- "with a layer index."
346
- )
347
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
348
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
349
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
350
-
351
- if past_key_value is not None:
352
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
353
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
354
-
355
- # repeat k/v heads if n_kv_heads < n_heads
356
- key_states = repeat_kv(key_states, self.num_key_value_groups)
357
- value_states = repeat_kv(value_states, self.num_key_value_groups)
358
-
359
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
360
-
361
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
362
- raise ValueError(
363
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
364
- f" {attn_weights.size()}"
365
- )
366
-
367
- if attention_mask is not None: # no matter the length, we just slice it
368
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
369
- attn_weights = attn_weights + causal_mask
370
-
371
- # upcast attention to fp32
372
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
373
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
374
- attn_output = torch.matmul(attn_weights, value_states)
375
-
376
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
377
- raise ValueError(
378
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
379
- f" {attn_output.size()}"
380
- )
381
-
382
- attn_output = attn_output.transpose(1, 2).contiguous()
383
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
384
-
385
- attn_output = self.o_proj(attn_output)
386
-
387
- if not output_attentions:
388
- attn_weights = None
389
-
390
- return attn_output, attn_weights, past_key_value
391
-
392
-
393
- # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
394
- class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
395
- """
396
- Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
397
- as the weights of the module stay untouched. The only required change is in the forward pass,
398
- where it needs to correctly call the public API of flash attention and deal with padding tokens
399
- in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
400
- config.max_window_layers layers.
401
- """
402
-
403
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
404
- def __init__(self, *args, **kwargs):
405
- super().__init__(*args, **kwargs)
406
-
407
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
408
- # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
409
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
410
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
411
-
412
- def forward(
413
- self,
414
- hidden_states: torch.Tensor,
415
- attention_mask: Optional[torch.Tensor] = None,
416
- position_ids: Optional[torch.LongTensor] = None,
417
- past_key_value: Optional[Cache] = None,
418
- output_attentions: bool = False,
419
- use_cache: bool = False,
420
- cache_position: Optional[torch.LongTensor] = None,
421
- ):
422
- bsz, q_len, _ = hidden_states.size()
423
-
424
- query_states = self.q_proj(hidden_states)
425
- key_states = self.k_proj(hidden_states)
426
- value_states = self.v_proj(hidden_states)
427
-
428
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
429
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
430
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
431
-
432
- kv_seq_len = key_states.shape[-2]
433
- if past_key_value is not None:
434
- if self.layer_idx is None:
435
- raise ValueError(
436
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
437
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
438
- "with a layer index."
439
- )
440
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
441
-
442
- # Because the input can be padded, the absolute sequence length depends on the max position id.
443
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
444
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
445
-
446
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
447
-
448
- use_sliding_windows = (
449
- _flash_supports_window_size
450
- and getattr(self.config, "sliding_window", None) is not None
451
- and kv_seq_len > self.config.sliding_window
452
- and self.config.use_sliding_window
453
- )
454
-
455
- if not _flash_supports_window_size:
456
- logger.warning_once(
457
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
458
- " make sure to upgrade flash-attn library."
459
- )
460
-
461
- if past_key_value is not None:
462
- # Activate cache slicing only if the config has a `sliding_window` attribute
463
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
464
- if (
465
- getattr(self.config, "sliding_window", None) is not None
466
- and kv_seq_len > self.config.sliding_window
467
- and cache_has_contents
468
- ):
469
- slicing_tokens = 1 - self.config.sliding_window
470
-
471
- past_key = past_key_value[self.layer_idx][0]
472
- past_value = past_key_value[self.layer_idx][1]
473
-
474
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
475
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
476
-
477
- if past_key.shape[-2] != self.config.sliding_window - 1:
478
- raise ValueError(
479
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
480
- f" {past_key.shape}"
481
- )
482
-
483
- if attention_mask is not None:
484
- attention_mask = attention_mask[:, slicing_tokens:]
485
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
486
-
487
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
488
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
489
-
490
- # repeat k/v heads if n_kv_heads < n_heads
491
- key_states = repeat_kv(key_states, self.num_key_value_groups)
492
- value_states = repeat_kv(value_states, self.num_key_value_groups)
493
- dropout_rate = 0.0 if not self.training else self.attention_dropout
494
-
495
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
496
- # therefore the input hidden states get silently cast to float32. Hence, we need to
497
- # cast them back to float16 just to be sure everything works as expected.
498
- input_dtype = query_states.dtype
499
- if input_dtype == torch.float32:
500
- if torch.is_autocast_enabled():
501
- target_dtype = torch.get_autocast_gpu_dtype()
502
- # Handle the case where the model is quantized
503
- elif hasattr(self.config, "_pre_quantization_dtype"):
504
- target_dtype = self.config._pre_quantization_dtype
505
- else:
506
- target_dtype = self.q_proj.weight.dtype
507
-
508
- logger.warning_once(
509
- f"The input hidden states seems to be silently casted in float32, this might be related to"
510
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
511
- f" {target_dtype}."
512
- )
513
-
514
- query_states = query_states.to(target_dtype)
515
- key_states = key_states.to(target_dtype)
516
- value_states = value_states.to(target_dtype)
517
-
518
- # Reshape to the expected shape for Flash Attention
519
- query_states = query_states.transpose(1, 2)
520
- key_states = key_states.transpose(1, 2)
521
- value_states = value_states.transpose(1, 2)
522
-
523
- attn_output = self._flash_attention_forward(
524
- query_states,
525
- key_states,
526
- value_states,
527
- attention_mask,
528
- q_len,
529
- dropout=dropout_rate,
530
- use_sliding_windows=use_sliding_windows,
531
- )
532
-
533
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
534
- attn_output = self.o_proj(attn_output)
535
-
536
- if not output_attentions:
537
- attn_weights = None
538
-
539
- return attn_output, attn_weights, past_key_value
540
-
541
- def _flash_attention_forward(
542
- self,
543
- query_states,
544
- key_states,
545
- value_states,
546
- attention_mask,
547
- query_length,
548
- dropout=0.0,
549
- softmax_scale=None,
550
- use_sliding_windows=False,
551
- ):
552
- """
553
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
554
- it first unpads the input, then computes the attention scores and pads the final attention scores.
555
-
556
- Args:
557
- query_states (`torch.Tensor`):
558
- Input query states to be passed to Flash Attention API
559
- key_states (`torch.Tensor`):
560
- Input key states to be passed to Flash Attention API
561
- value_states (`torch.Tensor`):
562
- Input value states to be passed to Flash Attention API
563
- attention_mask (`torch.Tensor`):
564
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
565
- position of padding tokens and 1 for the position of non-padding tokens.
566
- dropout (`float`):
567
- Attention dropout
568
- softmax_scale (`float`, *optional*):
569
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
570
- use_sliding_windows (`bool`, *optional*):
571
- Whether to activate sliding window attention.
572
- """
573
- if not self._flash_attn_uses_top_left_mask:
574
- causal = self.is_causal
575
- else:
576
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
577
- causal = self.is_causal and query_length != 1
578
-
579
- # Decide whether to use SWA or not by layer index.
580
- if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
581
- use_sliding_windows = False
582
-
583
- # Contains at least one padding token in the sequence
584
- if attention_mask is not None:
585
- batch_size = query_states.shape[0]
586
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
587
- query_states, key_states, value_states, attention_mask, query_length
588
- )
589
-
590
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
591
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
592
-
593
- if not use_sliding_windows:
594
- attn_output_unpad = flash_attn_varlen_func(
595
- query_states,
596
- key_states,
597
- value_states,
598
- cu_seqlens_q=cu_seqlens_q,
599
- cu_seqlens_k=cu_seqlens_k,
600
- max_seqlen_q=max_seqlen_in_batch_q,
601
- max_seqlen_k=max_seqlen_in_batch_k,
602
- dropout_p=dropout,
603
- softmax_scale=softmax_scale,
604
- causal=causal,
605
- )
606
- else:
607
- attn_output_unpad = flash_attn_varlen_func(
608
- query_states,
609
- key_states,
610
- value_states,
611
- cu_seqlens_q=cu_seqlens_q,
612
- cu_seqlens_k=cu_seqlens_k,
613
- max_seqlen_q=max_seqlen_in_batch_q,
614
- max_seqlen_k=max_seqlen_in_batch_k,
615
- dropout_p=dropout,
616
- softmax_scale=softmax_scale,
617
- causal=causal,
618
- window_size=(self.config.sliding_window, self.config.sliding_window),
619
- )
620
-
621
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
622
- else:
623
- if not use_sliding_windows:
624
- attn_output = flash_attn_func(
625
- query_states,
626
- key_states,
627
- value_states,
628
- dropout,
629
- softmax_scale=softmax_scale,
630
- causal=causal,
631
- )
632
- else:
633
- attn_output = flash_attn_func(
634
- query_states,
635
- key_states,
636
- value_states,
637
- dropout,
638
- softmax_scale=softmax_scale,
639
- causal=causal,
640
- window_size=(self.config.sliding_window, self.config.sliding_window),
641
- )
642
-
643
- return attn_output
644
-
645
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
646
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
647
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
648
-
649
- # On the first iteration we need to properly re-create the padding mask
650
- # by slicing it at the proper place
651
- if kv_seq_len != attention_mask.shape[-1]:
652
- attention_mask_num_tokens = attention_mask.shape[-1]
653
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
654
-
655
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
656
-
657
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
658
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
659
-
660
- if query_length == kv_seq_len:
661
- query_layer = index_first_axis(
662
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
663
- )
664
- cu_seqlens_q = cu_seqlens_k
665
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
666
- indices_q = indices_k
667
- elif query_length == 1:
668
- max_seqlen_in_batch_q = 1
669
- cu_seqlens_q = torch.arange(
670
- batch_size + 1, dtype=torch.int32, device=query_layer.device
671
- ) # There is a memcpy here, that is very bad.
672
- indices_q = cu_seqlens_q[:-1]
673
- query_layer = query_layer.squeeze(1)
674
- else:
675
- # The -q_len: slice assumes left padding.
676
- attention_mask = attention_mask[:, -query_length:]
677
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
678
-
679
- return (
680
- query_layer,
681
- key_layer,
682
- value_layer,
683
- indices_q,
684
- (cu_seqlens_q, cu_seqlens_k),
685
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
686
- )
687
-
688
-
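The variable-length branch of `_flash_attention_forward` relies on `_get_unpad_data` (defined near the top of the file) to turn the padding mask into the cumulative sequence lengths expected by `flash_attn_varlen_func`; a small illustration on a toy mask:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])   # two right-padded sequences
    indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
    # indices    -> flattened positions of the 5 real tokens: [0, 1, 2, 4, 5]
    # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)
    # max_seqlen -> 3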
689
- # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe
690
- class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
691
- """
692
- Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
693
- `Qwen2MoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
694
- SDPA API.
695
- """
696
-
697
- # Adapted from Qwen2MoeAttention.forward
698
- def forward(
699
- self,
700
- hidden_states: torch.Tensor,
701
- attention_mask: Optional[torch.Tensor] = None,
702
- position_ids: Optional[torch.LongTensor] = None,
703
- past_key_value: Optional[Cache] = None,
704
- output_attentions: bool = False,
705
- use_cache: bool = False,
706
- cache_position: Optional[torch.LongTensor] = None,
707
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
708
- if output_attentions:
709
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
710
- logger.warning_once(
711
- "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
712
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
713
- )
714
- return super().forward(
715
- hidden_states=hidden_states,
716
- attention_mask=attention_mask,
717
- position_ids=position_ids,
718
- past_key_value=past_key_value,
719
- output_attentions=output_attentions,
720
- use_cache=use_cache,
721
- )
722
-
723
- bsz, q_len, _ = hidden_states.size()
724
-
725
- query_states = self.q_proj(hidden_states)
726
- key_states = self.k_proj(hidden_states)
727
- value_states = self.v_proj(hidden_states)
728
-
729
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
730
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
731
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
732
-
733
- kv_seq_len = key_states.shape[-2]
734
- if past_key_value is not None:
735
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
736
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
737
-
738
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
739
-
740
- if past_key_value is not None:
741
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
742
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
743
-
744
- key_states = repeat_kv(key_states, self.num_key_value_groups)
745
- value_states = repeat_kv(value_states, self.num_key_value_groups)
746
-
747
- causal_mask = attention_mask
748
- if attention_mask is not None: # no matter the length, we just slice it
749
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
750
-
751
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
752
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
753
- if query_states.device.type == "cuda" and attention_mask is not None:
754
- query_states = query_states.contiguous()
755
- key_states = key_states.contiguous()
756
- value_states = value_states.contiguous()
757
-
758
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
759
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
760
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
761
- is_causal = True if causal_mask is None and q_len > 1 else False
762
-
763
- attn_output = torch.nn.functional.scaled_dot_product_attention(
764
- query_states,
765
- key_states,
766
- value_states,
767
- attn_mask=causal_mask,
768
- dropout_p=self.attention_dropout if self.training else 0.0,
769
- is_causal=is_causal,
770
- )
771
-
772
- attn_output = attn_output.transpose(1, 2).contiguous()
773
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
774
-
775
- attn_output = self.o_proj(attn_output)
776
-
777
- return attn_output, None, past_key_value
778
-
779
-
780
- QWEN2MOE_ATTENTION_CLASSES = {
781
- "eager": Qwen2MoeAttention,
782
- "flash_attention_2": Qwen2MoeFlashAttention2,
783
- "sdpa": Qwen2MoeSdpaAttention,
784
- }
785
-
786
-
787
- class Qwen2MoeSparseMoeBlock(nn.Module):
788
- def __init__(self, config):
789
- super().__init__()
790
- self.num_experts = config.num_experts
791
- self.top_k = config.num_experts_per_tok
792
- self.norm_topk_prob = config.norm_topk_prob
793
-
794
- # gating
795
- self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
796
- self.experts = nn.ModuleList(
797
- [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
798
- )
799
- # shared expert
800
- self.share_flag=config.share_flag
801
-
802
- if self.share_flag:
803
- self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
804
- self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
805
-
806
- # language-specific expert gating
807
- self.language_gate=config.language_gate
808
-
809
- def forward(self, hidden_states: torch.Tensor,language_ids:Optional[torch.LongTensor] = None) -> torch.Tensor:
810
-
811
- batch_size, sequence_length, hidden_dim = hidden_states.shape
812
- hidden_states = hidden_states.view(-1, hidden_dim)
813
- if self.language_gate and self.training :
814
- if language_ids is None:
815
- raise ValueError('language_ids is not initialized')
816
- language_ids=language_ids.view(batch_size*sequence_length,-1)
817
- # router_logits: (batch * sequence_length, n_experts)
818
- router_logits = self.gate(hidden_states)
819
-
820
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
821
-
822
- _, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
823
-
824
- # language-specific routing: force the token's language expert into the selected top-k set
825
- if self.language_gate and self.training:
826
- if language_ids is None:
827
- raise ValueError('language_ids is not initialized')
828
- assert language_ids.shape[0]==selected_experts.shape[0],f'{language_ids.shape},{selected_experts.shape}'
829
- language_experts=language_ids.to(selected_experts.dtype)
830
- mask=torch.sum((language_experts==selected_experts).int(),dim=1,keepdims=True).bool()
831
- selected_experts[:,-1]=torch.where(mask.squeeze(),selected_experts[:,-1],language_experts.squeeze())
832
- routing_weights=torch.gather(routing_weights,1,selected_experts)
833
- else:
834
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
835
-
836
- if self.norm_topk_prob:
837
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
838
- # we cast back to the input dtype
839
- routing_weights = routing_weights.to(hidden_states.dtype)
840
-
841
- final_hidden_states = torch.zeros(
842
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
843
- )
844
-
845
- # One hot encode the selected experts to create an expert mask
846
- # this will be used to easily index which expert is going to be solicited
847
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
848
-
849
- # Loop over all available experts in the model and perform the computation on each expert
850
- for expert_idx in range(self.num_experts):
851
- expert_layer = self.experts[expert_idx]
852
- idx, top_x = torch.where(expert_mask[expert_idx])
853
-
854
- # Index the correct hidden states and compute the expert hidden state for
855
- # the current expert. We need to make sure to multiply the output hidden
856
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
857
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
858
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
859
-
860
- # However `index_add_` only support torch tensors for indexing so we'll use
861
- # the `top_x` tensor here.
862
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
863
-
864
- if self.share_flag:
865
-
866
- shared_expert_output = self.shared_expert(hidden_states)
867
- shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
868
-
869
- final_hidden_states = final_hidden_states + shared_expert_output
870
-
871
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
872
- return final_hidden_states, router_logits
873
-
874
-
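A compact sketch of exercising the routing block above with a stand-in config (a SimpleNamespace carrying only the fields the block reads; the values are illustrative, not the released model's hyperparameters):

    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=16, moe_intermediate_size=32, shared_expert_intermediate_size=32,
        num_experts=4, num_experts_per_tok=2, norm_topk_prob=True,
        share_flag=True, language_gate=False, hidden_act="silu",
    )
    block = Qwen2MoeSparseMoeBlock(cfg)
    hidden = torch.randn(1, 6, 16)                # (batch, seq_len, hidden_size)
    out, router_logits = block(hidden)
    # out: (1, 6, 16); router_logits: (6, num_experts), later fed to the auxiliary loss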
875
- class Qwen2MoeDecoderLayer(nn.Module):
876
- def __init__(self, config: UpcyclingQwen2MoeConfig, layer_idx: int):
877
- super().__init__()
878
- self.hidden_size = config.hidden_size
879
-
880
- self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
881
-
882
- if (layer_idx not in config.mlp_only_layers) and (
883
- config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
884
- ):
885
- self.mlp = Qwen2MoeSparseMoeBlock(config)
886
- else:
887
- self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)
888
-
889
- self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
890
- self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
891
-
892
- def forward(
893
- self,
894
- hidden_states: torch.Tensor,
895
- language_ids:Optional[torch.LongTensor] = None,
896
- attention_mask: Optional[torch.Tensor] = None,
897
- position_ids: Optional[torch.LongTensor] = None,
898
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
899
- output_attentions: Optional[bool] = False,
900
- output_router_logits: Optional[bool] = False,
901
- use_cache: Optional[bool] = False,
902
- cache_position: Optional[torch.LongTensor] = None,
903
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
904
- """
905
- Args:
906
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
907
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
908
- `(batch, sequence_length)` where padding elements are indicated by 0.
909
- output_attentions (`bool`, *optional*):
910
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
911
- returned tensors for more detail.
912
- output_router_logits (`bool`, *optional*):
913
- Whether or not to return the logits of all the routers. They are useful for computing the router loss,
914
- and should not be returned during inference.
915
- use_cache (`bool`, *optional*):
916
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
917
- (see `past_key_values`).
918
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
919
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
920
- Indices depicting the position of the input sequence tokens in the sequence.
921
- """
922
-
923
- residual = hidden_states
924
-
925
- hidden_states = self.input_layernorm(hidden_states)
926
-
927
- # Self Attention
928
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
929
- hidden_states=hidden_states,
930
- attention_mask=attention_mask,
931
- position_ids=position_ids,
932
- past_key_value=past_key_value,
933
- output_attentions=output_attentions,
934
- use_cache=use_cache,
935
- cache_position=cache_position,
936
- )
937
- hidden_states = residual + hidden_states
938
-
939
- # Fully Connected
940
- residual = hidden_states
941
- hidden_states = self.post_attention_layernorm(hidden_states)
942
-
943
- hidden_states = self.mlp(hidden_states,language_ids)
944
- if isinstance(hidden_states, tuple):
945
- hidden_states, router_logits = hidden_states
946
- else:
947
- router_logits = None
948
-
949
- hidden_states = residual + hidden_states
950
-
951
- outputs = (hidden_states,)
952
-
953
- if output_attentions:
954
- outputs += (self_attn_weights,)
955
-
956
- if use_cache:
957
- outputs += (present_key_value,)
958
-
959
- if output_router_logits:
960
- outputs += (router_logits,)
961
-
962
- return outputs
963
-
964
-
965
- class UpcyclingQwen2MoePreTrainedModel(PreTrainedModel):
966
- config_class = UpcyclingQwen2MoeConfig
967
- base_model_prefix = "model"
968
- supports_gradient_checkpointing = True
969
- _no_split_modules = ["Qwen2MoeDecoderLayer"]
970
- _skip_keys_device_placement = "past_key_values"
971
- _supports_flash_attn_2 = True
972
- _supports_sdpa = True
973
- _supports_cache_class = True
974
-
975
- def _init_weights(self, module):
976
- std = self.config.initializer_range
977
- if isinstance(module, nn.Linear):
978
- module.weight.data.normal_(mean=0.0, std=std)
979
- if module.bias is not None:
980
- module.bias.data.zero_()
981
- elif isinstance(module, nn.Embedding):
982
- module.weight.data.normal_(mean=0.0, std=std)
983
- if module.padding_idx is not None:
984
- module.weight.data[module.padding_idx].zero_()
985
-
986
- @classmethod
987
- def from_qwen(cls, pretrained_model_name_or_path, *model_args, **kwargs):
988
- share_flag=kwargs.pop('share_flag')
989
- attn_init_change=kwargs.pop('attn_init_change')
990
- language_gate=kwargs.pop('language_gate')
991
-
992
- config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
993
-
994
- config.share_flag=True if isinstance(share_flag,bool) and share_flag else False
995
- config.attn_init_change=True if isinstance(attn_init_change,bool) and attn_init_change else False
996
- config.language_gate=True if isinstance(language_gate,bool) and language_gate else False
997
-
998
- print('share_flag',config.share_flag)
999
- print('attn_init_change',config.attn_init_change)
1000
- print('language_gate',config.language_gate)
1001
-
1002
- config.num_experts_per_tok = config.num_experts_per_tok if not config.share_flag else config.num_experts_per_tok-1
1003
- config.num_experts = config.num_experts if not config.share_flag else config.num_experts-1
1004
-
1005
- base_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
1006
- base_cls = type(base_model)
1007
-
1008
- print(cls.config_class,cls)
1009
-
1010
- #create auto_map
1011
- #allows you to use your custom model with the auto-API (but doesn’t share any custom code with other users).
1012
- cls.config_class.register_for_auto_class()
1013
- cls.register_for_auto_class('AutoModelForCausalLM')
1014
-
1015
- # assert base_cls.__name__ == "Qwen2ForCausalLM", f"Invalid convert base model type: {base_cls}"
1016
-
1017
- model = cls(config)
1018
- print(f"converting {base_cls.__name__} to {cls.__name__}")
1019
-
1020
- # MoE architecture
1021
- model_dict=model.state_dict()
1022
- base_model_dict = base_model.state_dict()
1023
-
1024
- #lm_head
1025
- print('lm_head.weight',model_dict['lm_head.weight'],base_model_dict['lm_head.weight'])
1026
-
1027
- shared_keys=set(model_dict)&set(base_model_dict)
1028
- init_keys=[]
1029
- #attention
1030
- for k in shared_keys:
1031
- if k not in init_keys and 'self_attn' in k:
1032
- init_keys.append(k)
1033
- if not config.attn_init_change:
1034
- model_dict[k]=base_model_dict[k]
1035
-
1036
- if config.attn_init_change:
1037
- # initialization: average each layer's attention weights with its upper and lower neighbors
1038
- for layer_id in range(config.num_hidden_layers):
1039
- if layer_id == 0 or layer_id == config.num_hidden_layers-1:
1040
- model_dict[f'model.layers.{layer_id}.self_attn.q_proj.bias']=base_model_dict[f'model.layers.{layer_id}.self_attn.q_proj.bias']
1041
- model_dict[f'model.layers.{layer_id}.self_attn.q_proj.weight']=base_model_dict[f'model.layers.{layer_id}.self_attn.q_proj.weight']
1042
- model_dict[f'model.layers.{layer_id}.self_attn.k_proj.bias']=base_model_dict[f'model.layers.{layer_id}.self_attn.k_proj.bias']
1043
- model_dict[f'model.layers.{layer_id}.self_attn.k_proj.weight']=base_model_dict[f'model.layers.{layer_id}.self_attn.k_proj.weight']
1044
- model_dict[f'model.layers.{layer_id}.self_attn.v_proj.bias']=base_model_dict[f'model.layers.{layer_id}.self_attn.v_proj.bias']
1045
- model_dict[f'model.layers.{layer_id}.self_attn.v_proj.weight']=base_model_dict[f'model.layers.{layer_id}.self_attn.v_proj.weight']
1046
- model_dict[f'model.layers.{layer_id}.self_attn.o_proj.weight']=base_model_dict[f'model.layers.{layer_id}.self_attn.o_proj.weight']
1047
- else:
1048
- model_dict[f'model.layers.{layer_id}.self_attn.q_proj.bias']=1/3*(base_model_dict[f'model.layers.{layer_id}.self_attn.q_proj.bias']+base_model_dict[f'model.layers.{layer_id+1}.self_attn.q_proj.bias']+base_model_dict[f'model.layers.{layer_id-1}.self_attn.q_proj.bias'])
1049
- model_dict[f'model.layers.{layer_id}.self_attn.q_proj.weight']=1/3*(base_model_dict[f'model.layers.{layer_id}.self_attn.q_proj.weight']+base_model_dict[f'model.layers.{layer_id+1}.self_attn.q_proj.weight']+base_model_dict[f'model.layers.{layer_id-1}.self_attn.q_proj.weight'])
1050
- model_dict[f'model.layers.{layer_id}.self_attn.k_proj.bias']=1/3*(base_model_dict[f'model.layers.{layer_id}.self_attn.k_proj.bias']+base_model_dict[f'model.layers.{layer_id+1}.self_attn.k_proj.bias']+base_model_dict[f'model.layers.{layer_id-1}.self_attn.k_proj.bias'])
1051
- model_dict[f'model.layers.{layer_id}.self_attn.k_proj.weight']=1/3*(base_model_dict[f'model.layers.{layer_id}.self_attn.k_proj.weight']+base_model_dict[f'model.layers.{layer_id+1}.self_attn.k_proj.weight']+base_model_dict[f'model.layers.{layer_id-1}.self_attn.k_proj.weight'])
1052
- model_dict[f'model.layers.{layer_id}.self_attn.v_proj.bias']=1/3*(base_model_dict[f'model.layers.{layer_id}.self_attn.v_proj.bias']+base_model_dict[f'model.layers.{layer_id+1}.self_attn.v_proj.bias']+base_model_dict[f'model.layers.{layer_id-1}.self_attn.v_proj.bias'])
1053
- model_dict[f'model.layers.{layer_id}.self_attn.v_proj.weight']=1/3*(base_model_dict[f'model.layers.{layer_id}.self_attn.v_proj.weight']+base_model_dict[f'model.layers.{layer_id+1}.self_attn.v_proj.weight']+base_model_dict[f'model.layers.{layer_id-1}.self_attn.v_proj.weight'])
1054
- model_dict[f'model.layers.{layer_id}.self_attn.o_proj.weight']=1/3*(base_model_dict[f'model.layers.{layer_id}.self_attn.o_proj.weight']+base_model_dict[f'model.layers.{layer_id+1}.self_attn.o_proj.weight']+base_model_dict[f'model.layers.{layer_id-1}.self_attn.o_proj.weight'])
1055
-
1056
- #mlp
1057
- if config.mlp_only_layers:
1058
- for layer_id in config.mlp_only_layers:
1059
- key_mapping=sum([
1060
- [
1061
- (f'model.layers.{layer_id}.mlp.down_proj.weight',f'model.layers.{layer_id}.mlp.down_proj.weight'),
1062
- (f'model.layers.{layer_id}.mlp.gate_proj.weight',f'model.layers.{layer_id}.mlp.gate_proj.weight'),
1063
- (f'model.layers.{layer_id}.mlp.up_proj.weight',f'model.layers.{layer_id}.mlp.up_proj.weight'),
1064
- ]]
1065
- ,[])
1066
- for model_key,base_model_key in key_mapping:
1067
- model_dict[model_key]=base_model_dict[base_model_key]
1068
- init_keys.append(model_key)
1069
- moe_only_layers=list(set(range(config.num_hidden_layers))-set(config.mlp_only_layers)) if config.mlp_only_layers else list(range(config.num_hidden_layers))
1070
- #moe-mlp-expert
1071
- for layer_id in moe_only_layers:
1072
- key_mapping=sum([
1073
- [
1074
- (f'model.layers.{layer_id}.mlp.experts.{expert_id}.down_proj.weight',f'model.layers.{layer_id}.mlp.down_proj.weight'),
1075
- (f'model.layers.{layer_id}.mlp.experts.{expert_id}.gate_proj.weight',f'model.layers.{layer_id}.mlp.gate_proj.weight'),
1076
- (f'model.layers.{layer_id}.mlp.experts.{expert_id}.up_proj.weight',f'model.layers.{layer_id}.mlp.up_proj.weight'),
1077
- ] for expert_id in range(config.num_experts)]
1078
- ,[])
1079
- for model_key,base_model_key in key_mapping:
1080
- model_dict[model_key]=base_model_dict[base_model_key]
1081
- init_keys.append(model_key)
1082
- #model_dict[f'model.layers.{layer_id}.mlp.gate.weight']
1083
-
1084
- #share expert
1085
- if config.share_flag:
1086
- shared_key_mapping=sum([[
1087
- (f'model.layers.{layer_id}.mlp.shared_expert.down_proj.weight',f'model.layers.{layer_id}.mlp.down_proj.weight'),
1088
- (f'model.layers.{layer_id}.mlp.shared_expert.gate_proj.weight',f'model.layers.{layer_id}.mlp.gate_proj.weight'),
1089
- (f'model.layers.{layer_id}.mlp.shared_expert.up_proj.weight',f'model.layers.{layer_id}.mlp.up_proj.weight'),
1090
- ]for layer_id in range(config.num_hidden_layers)],
1091
- [])
1092
- for model_key,base_model_key in shared_key_mapping:
1093
- model_dict[model_key]=base_model_dict[base_model_key]
1094
- init_keys.append(model_key)
1095
- # model_dict[f'model.layers.{layer_id}.mlp.shared_expert_gate.weight']
1096
-
1097
- #norm
1098
- for k in shared_keys:
1099
- if k not in init_keys:
1100
- #input_layernorm.weight,post_attention_layernorm.weight,norm.weight
1101
- # embed_token.weight,lm_head.weight
1102
- model_dict[k]=base_model_dict[k]
1103
- init_keys.append(k)
1104
-
1105
- gate_initialized = False
1106
- shared_gate_initilizaed=False
1107
- for key in model_dict.keys():
1108
- if key in init_keys:
1109
- continue
1110
- if "mlp.gate.weight" in key:
1111
- if gate_initialized:
1112
- continue
1113
- gate_initialized = True
1114
- print(f"{cls.__name__} key [{cls.base_model_prefix}.layers.[0-{config.num_hidden_layers-1}].mlp.gate.weight] is not initialized from {base_cls.__name__}. e.g, {key}")
1115
- continue
1116
- if 'shared_expert_gate.weight' in key:
1117
- if shared_gate_initilizaed:
1118
- continue
1119
- shared_gate_initilizaed = True
1120
- print(f"{cls.__name__} key [{cls.base_model_prefix}.layers.[0-{config.num_hidden_layers-1}].mlp.shared_expert_gate.weight] is not initialized from {base_cls.__name__}. e.g, {key}")
1121
- continue
1122
-
1123
- raise NotImplementedError(f"{cls.__name__} key [{key}] is not correctly initilized from {base_cls.__name__}.")
1124
-
1125
- model.load_state_dict(model_dict)
1126
- print(f"Done converted, alreadly check all parameters of {cls.__name__} are initialized from {base_cls.__name__}.")
1127
-
1128
- del base_model
1129
- return model
1130
-
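A hedged usage sketch of the conversion entry point above. The checkpoint name and flag values are placeholders, and UpcyclingQwen2MoeForCausalLM stands for the causal-LM subclass defined later in this file (the method reads `lm_head.weight`, so it is intended to be called on that subclass rather than on the base class):

    import torch

    # Hypothetical invocation: upcycle a dense Qwen2 checkpoint into this MoE layout.
    moe_model = UpcyclingQwen2MoeForCausalLM.from_qwen(
        "Qwen/Qwen2-7B",            # placeholder dense checkpoint
        share_flag=True,            # add a shared expert; one routed expert and one top-k slot are subtracted
        attn_init_change=False,     # copy attention weights layer-by-layer instead of averaging neighbors
        language_gate=False,        # no language-specific expert forcing
        torch_dtype=torch.bfloat16,
    )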
1131
- @classmethod
1132
- def from_btx(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1133
- share_flag=kwargs.pop('share_flag')
1134
- attn_init_change=kwargs.pop('attn_init_change')
1135
- language_gate=kwargs.pop('language_gate')
1136
-
1137
- config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
1138
-
1139
- config.share_flag=True if isinstance(share_flag,bool) and share_flag else False
1140
- config.attn_init_change=True if isinstance(attn_init_change,bool) and attn_init_change else False
1141
- config.language_gate=True if isinstance(language_gate,bool) and language_gate else False
1142
-
1143
- print('share_flag',config.share_flag)
1144
- print('attn_init_change',config.attn_init_change)
1145
- print('language_gate',config.language_gate)
1146
-
1147
- config.num_experts_per_tok = config.num_experts_per_tok if not config.share_flag else config.num_experts_per_tok-1
1148
- config.num_experts = config.num_experts if not config.share_flag else config.num_experts-1
1149
-
1150
- base_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
1151
- base_cls = type(base_model)
1152
-
1153
- print(cls.config_class,cls)
1154
-
1155
- #create auto_map
1156
- #allows you to use your custom model with the auto-API (but doesn’t share any custom code with other users).
1157
- cls.config_class.register_for_auto_class()
1158
- cls.register_for_auto_class('AutoModelForCausalLM')
1159
-
1160
- # assert base_cls.__name__ == "Qwen2ForCausalLM", f"Invalid convert base model type: {base_cls}"
1161
-
1162
- model = cls(config)
1163
- print(f"converting {base_cls.__name__} to {cls.__name__}")
1164
-
1165
- # MoE architecture
1166
- model_dict=model.state_dict()
1167
- base_model_dict = base_model.state_dict()
1168
-
1169
- #lm_head
1170
- print('lm_head.weight',model_dict['lm_head.weight'],base_model_dict['lm_head.weight'])
1171
-
1172
- shared_keys=set(model_dict)&set(base_model_dict)
1173
- init_keys=[]
1174
- #attention
1175
- for k in shared_keys:
1176
- init_keys.append(k)
1177
- model_dict[k]=base_model_dict[k]
1178
-
1179
- gate_initialized = False
1180
- shared_gate_initilizaed=False
1181
- for key in model_dict.keys():
1182
- if key in init_keys:
1183
- continue
1184
- if "mlp.gate.weight" in key:
1185
- if gate_initialized:
1186
- continue
1187
- gate_initialized = True
1188
- print(f"{cls.__name__} key [{cls.base_model_prefix}.layers.[0-{config.num_hidden_layers-1}].mlp.gate.weight] is not initialized from {base_cls.__name__}. e.g, {key}")
1189
- continue
1190
- if 'shared_expert_gate.weight' in key:
1191
- if shared_gate_initilizaed:
1192
- continue
1193
- shared_gate_initilizaed = True
1194
- print(f"{cls.__name__} key [{cls.base_model_prefix}.layers.[0-{config.num_hidden_layers-1}].mlp.shared_expert_gate.weight] is not initialized from {base_cls.__name__}. e.g, {key}")
1195
- continue
1196
-
1197
- raise NotImplementedError(f"{cls.__name__} key [{key}] is not correctly initilized from {base_cls.__name__}.")
1198
-
1199
- model.load_state_dict(model_dict)
1200
- print(f"Done converted, alreadly check all parameters of {cls.__name__} are initialized from {base_cls.__name__}.")
1201
-
1202
- del base_model
1203
- return model
1204
-
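For reference, a minimal usage sketch of the `from_btx` conversion path above (the causal-LM class it instantiates is defined further down in this file). The checkpoint name, dtype, and output directory are illustrative assumptions; `from_btx` pops `share_flag`, `attn_init_change`, and `language_gate` from the keyword arguments before loading the dense base model, so all three must be supplied.

import torch

# Hypothetical values throughout; the dense checkpoint and output path are
# placeholders, not anything taken from this repository.
moe_model = UpcyclingQwen2MoeForCausalLM.from_btx(
    "Qwen/Qwen2-7B",             # assumed dense base checkpoint
    share_flag=True,             # reserve one expert slot as the shared expert
    attn_init_change=False,
    language_gate=True,          # enable language-aware routing
    torch_dtype=torch.bfloat16,  # remaining kwargs go to AutoModelForCausalLM.from_pretrained
)
moe_model.save_pretrained("./upcycled-qwen2-moe")  # persist the converted weights

When `share_flag` is set, `num_experts` and `num_experts_per_tok` are each reduced by one before the model is built, which matches the shared-expert handling above.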
1205
- class UpcyclingQwen2MoeModel(UpcyclingQwen2MoePreTrainedModel):
1206
- """
1207
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`].
1208
-
1209
- Args:
1210
- config: UpcyclingQwen2MoeConfig
1211
- """
1212
-
1213
- def __init__(self, config: UpcyclingQwen2MoeConfig):
1214
- super().__init__(config)
1215
- self.padding_idx = config.pad_token_id
1216
- self.vocab_size = config.vocab_size
1217
-
1218
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1219
- self.layers = nn.ModuleList(
1220
- [Qwen2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1221
- )
1222
- self._attn_implementation = config._attn_implementation
1223
- self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1224
-
1225
- self.gradient_checkpointing = False
1226
- # Initialize weights and apply final processing
1227
- self.post_init()
1228
-
1229
- def get_input_embeddings(self):
1230
- return self.embed_tokens
1231
-
1232
- def set_input_embeddings(self, value):
1233
- self.embed_tokens = value
1234
-
1235
- def forward(
1236
- self,
1237
- input_ids: torch.LongTensor = None,
1238
- language_ids: Optional[torch.LongTensor] = None,
1239
- attention_mask: Optional[torch.Tensor] = None,
1240
- position_ids: Optional[torch.LongTensor] = None,
1241
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1242
- inputs_embeds: Optional[torch.FloatTensor] = None,
1243
- use_cache: Optional[bool] = None,
1244
- output_attentions: Optional[bool] = None,
1245
- output_hidden_states: Optional[bool] = None,
1246
- output_router_logits: Optional[bool] = None,
1247
- return_dict: Optional[bool] = None,
1248
- cache_position: Optional[torch.LongTensor] = None,
1249
- ) -> Union[Tuple, MoeModelOutputWithPast]:
1250
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1251
- output_router_logits = (
1252
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1253
- )
1254
- output_hidden_states = (
1255
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1256
- )
1257
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1258
-
1259
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1260
-
1261
- if (input_ids is None) ^ (inputs_embeds is not None):
1262
- raise ValueError(
1263
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1264
- )
1265
-
1266
- if self.gradient_checkpointing and self.training:
1267
- if use_cache:
1268
- logger.warning_once(
1269
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1270
- )
1271
- use_cache = False
1272
-
1273
- use_legacy_cache = False
1274
- if use_cache and not isinstance(past_key_values, Cache):
1275
- use_legacy_cache = True
1276
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1277
- logger.warning_once(
1278
- "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1279
- "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
1280
- )
1281
-
1282
- if inputs_embeds is None:
1283
- inputs_embeds = self.embed_tokens(input_ids)
1284
-
1285
- if cache_position is None:
1286
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1287
- cache_position = torch.arange(
1288
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1289
- )
1290
- if position_ids is None:
1291
- position_ids = cache_position.unsqueeze(0)
1292
-
1293
- causal_mask = self._update_causal_mask(
1294
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1295
- )
1296
-
1297
- hidden_states = inputs_embeds
1298
-
1299
- # decoder layers
1300
- all_hidden_states = () if output_hidden_states else None
1301
- all_self_attns = () if output_attentions else None
1302
- all_router_logits = () if output_router_logits else None
1303
- next_decoder_cache = None
1304
-
1305
- for decoder_layer in self.layers:
1306
- if output_hidden_states:
1307
- all_hidden_states += (hidden_states,)
1308
-
1309
- if self.gradient_checkpointing and self.training:
1310
- layer_outputs = self._gradient_checkpointing_func(
1311
- decoder_layer.__call__,
1312
- hidden_states,
1313
- language_ids,
1314
- causal_mask,
1315
- position_ids,
1316
- past_key_values,
1317
- output_attentions,
1318
- output_router_logits,
1319
- use_cache,
1320
- cache_position,
1321
- )
1322
- else:
1323
- layer_outputs = decoder_layer(
1324
- hidden_states,
1325
- language_ids,
1326
- attention_mask=causal_mask,
1327
- position_ids=position_ids,
1328
- past_key_value=past_key_values,
1329
- output_attentions=output_attentions,
1330
- output_router_logits=output_router_logits,
1331
- use_cache=use_cache,
1332
- cache_position=cache_position,
1333
- )
1334
-
1335
- hidden_states = layer_outputs[0]
1336
-
1337
- if use_cache:
1338
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1339
-
1340
- if output_attentions:
1341
- all_self_attns += (layer_outputs[1],)
1342
-
1343
- if output_router_logits and layer_outputs[-1] is not None:
1344
- all_router_logits += (layer_outputs[-1],)
1345
-
1346
- hidden_states = self.norm(hidden_states)
1347
-
1348
- # add hidden states from the last decoder layer
1349
- if output_hidden_states:
1350
- all_hidden_states += (hidden_states,)
1351
-
1352
- next_cache = None
1353
- if use_cache:
1354
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1355
-
1356
- if not return_dict:
1357
- return tuple(
1358
- v
1359
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1360
- if v is not None
1361
- )
1362
- return MoeModelOutputWithPast(
1363
- last_hidden_state=hidden_states,
1364
- past_key_values=next_cache,
1365
- hidden_states=all_hidden_states,
1366
- attentions=all_self_attns,
1367
- router_logits=all_router_logits,
1368
- )
1369
-
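A small, self-contained sketch of the legacy-cache conversion performed at the top of the forward pass above; the tensor shapes are toy assumptions (batch, num_kv_heads, seq_len, head_dim).

import torch
from transformers.cache_utils import DynamicCache

# Two decoder layers, each contributing a (key, value) pair; shapes are illustrative only.
legacy_cache = tuple(
    (torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8))
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy_cache)
print(cache.get_seq_length())  # 3 tokens already cached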
1370
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1371
- def _update_causal_mask(
1372
- self,
1373
- attention_mask: torch.Tensor,
1374
- input_tensor: torch.Tensor,
1375
- cache_position: torch.Tensor,
1376
- past_key_values: Cache,
1377
- output_attentions: bool,
1378
- ):
1379
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1380
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1381
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1382
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1383
-
1384
- if self.config._attn_implementation == "flash_attention_2":
1385
- if attention_mask is not None and 0.0 in attention_mask:
1386
- return attention_mask
1387
- return None
1388
-
1389
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1390
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1391
- # to infer the attention mask.
1392
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1393
- using_static_cache = isinstance(past_key_values, StaticCache)
1394
-
1395
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1396
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1397
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1398
- attention_mask,
1399
- inputs_embeds=input_tensor,
1400
- past_key_values_length=past_seen_tokens,
1401
- is_training=self.training,
1402
- ):
1403
- return None
1404
-
1405
- dtype, device = input_tensor.dtype, input_tensor.device
1406
- min_dtype = torch.finfo(dtype).min
1407
- sequence_length = input_tensor.shape[1]
1408
- if using_static_cache:
1409
- target_length = past_key_values.get_max_length()
1410
- else:
1411
- target_length = (
1412
- attention_mask.shape[-1]
1413
- if isinstance(attention_mask, torch.Tensor)
1414
- else past_seen_tokens + sequence_length + 1
1415
- )
1416
-
1417
- if attention_mask is not None and attention_mask.dim() == 4:
1418
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
1419
- if attention_mask.max() != 0:
1420
- raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
1421
- causal_mask = attention_mask
1422
- else:
1423
- causal_mask = torch.full(
1424
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1425
- )
1426
- if sequence_length != 1:
1427
- causal_mask = torch.triu(causal_mask, diagonal=1)
1428
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1429
- causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1430
- if attention_mask is not None:
1431
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1432
- mask_length = attention_mask.shape[-1]
1433
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1434
- padding_mask = padding_mask == 0
1435
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1436
- padding_mask, min_dtype
1437
- )
1438
- if (
1439
- self.config._attn_implementation == "sdpa"
1440
- and attention_mask is not None
1441
- and attention_mask.device.type == "cuda"
1442
- and not output_attentions
1443
- ):
1444
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1445
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1446
- # Details: https://github.com/pytorch/pytorch/issues/110213
1447
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1448
-
1449
- return causal_mask
1450
-
1451
-
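To make the mask arithmetic in `_update_causal_mask` concrete, here is a standalone sketch of the `torch.full` / `torch.triu` / `cache_position` steps with a toy setting of 2 cached tokens and 4 new tokens; the sizes are assumptions, not model values.

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
past_seen_tokens, sequence_length = 2, 4
target_length = past_seen_tokens + sequence_length
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + sequence_length)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)                            # future positions stay masked
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)   # unmask already-cached positions
# Row i now blocks every absolute position j > past_seen_tokens + i, so each new token
# attends to all cached tokens plus itself and the new tokens before it.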
1452
- class UpcyclingQwen2MoeForCausalLM(UpcyclingQwen2MoePreTrainedModel):
1453
- _tied_weights_keys = ["lm_head.weight"]
1454
-
1455
- def __init__(self, config):
1456
- super().__init__(config)
1457
- self.model = UpcyclingQwen2MoeModel(config)
1458
- self.vocab_size = config.vocab_size
1459
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1460
-
1461
- self.router_aux_loss_coef = config.router_aux_loss_coef
1462
- self.num_experts = config.num_experts
1463
- self.num_experts_per_tok = config.num_experts_per_tok
1464
-
1465
- self.language_gate=config.language_gate
1466
- # Initialize weights and apply final processing
1467
- self.post_init()
1468
-
1469
- def get_input_embeddings(self):
1470
- return self.model.embed_tokens
1471
-
1472
- def set_input_embeddings(self, value):
1473
- self.model.embed_tokens = value
1474
-
1475
- def get_output_embeddings(self):
1476
- return self.lm_head
1477
-
1478
- def set_output_embeddings(self, new_embeddings):
1479
- self.lm_head = new_embeddings
1480
-
1481
- def set_decoder(self, decoder):
1482
- self.model = decoder
1483
-
1484
- def get_decoder(self):
1485
- return self.model
1486
-
1487
- # @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1488
- def forward(
1489
- self,
1490
- input_ids: torch.LongTensor = None,
1491
- language_ids: Optional[torch.LongTensor] = None,
1492
- attention_mask: Optional[torch.Tensor] = None,
1493
- position_ids: Optional[torch.LongTensor] = None,
1494
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1495
- inputs_embeds: Optional[torch.FloatTensor] = None,
1496
- labels: Optional[torch.LongTensor] = None,
1497
- use_cache: Optional[bool] = None,
1498
- output_attentions: Optional[bool] = None,
1499
- output_hidden_states: Optional[bool] = None,
1500
- output_router_logits: Optional[bool] = None,
1501
- return_dict: Optional[bool] = None,
1502
- cache_position: Optional[torch.LongTensor] = None,
1503
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1504
-
1505
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1506
- output_router_logits = (
1507
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1508
- )
1509
- output_hidden_states = (
1510
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1511
- )
1512
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1513
-
1514
- # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1515
- outputs = self.model(
1516
- input_ids=input_ids,
1517
- language_ids=language_ids,
1518
- attention_mask=attention_mask,
1519
- position_ids=position_ids,
1520
- past_key_values=past_key_values,
1521
- inputs_embeds=inputs_embeds,
1522
- use_cache=use_cache,
1523
- output_attentions=output_attentions,
1524
- output_hidden_states=output_hidden_states,
1525
- output_router_logits=output_router_logits,
1526
- return_dict=return_dict,
1527
- cache_position=cache_position,
1528
- )
1529
-
1530
- hidden_states = outputs[0]
1531
- logits = self.lm_head(hidden_states)
1532
- logits = logits.float()
1533
-
1534
- loss = None
1535
- if labels is not None:
1536
- # Shift so that tokens < n predict n
1537
- shift_logits = logits[..., :-1, :].contiguous()
1538
- shift_labels = labels[..., 1:].contiguous()
1539
- # Flatten the tokens
1540
- loss_fct = CrossEntropyLoss()
1541
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1542
- shift_labels = shift_labels.view(-1)
1543
- # Enable model parallelism
1544
- shift_labels = shift_labels.to(shift_logits.device)
1545
- loss = loss_fct(shift_logits, shift_labels)
1546
-
1547
- aux_loss = None
1548
- if output_router_logits:
1549
- aux_loss = load_balancing_loss_func(
1550
- outputs.router_logits if return_dict else outputs[-1],
1551
- self.num_experts,
1552
- self.num_experts_per_tok,
1553
- attention_mask,
1554
- )
1555
- if labels is not None:
1556
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # keep aux_loss on the same device as the main loss
1557
-
1558
- if not return_dict:
1559
- output = (logits,) + outputs[1:]
1560
- if output_router_logits:
1561
- output = (aux_loss,) + output
1562
- return (loss,) + output if loss is not None else output
1563
-
1564
- return MoeCausalLMOutputWithPast(
1565
- loss=loss,
1566
- aux_loss=aux_loss,
1567
- logits=logits,
1568
- past_key_values=outputs.past_key_values,
1569
- hidden_states=outputs.hidden_states,
1570
- attentions=outputs.attentions,
1571
- router_logits=outputs.router_logits,
1572
- )
1573
-
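A toy sketch of how the auxiliary loss computed above folds into the total objective, using the `load_balancing_loss_func` defined at the top of this module. The coefficient, logits, and cross-entropy value are illustrative assumptions; each tuple element stands in for one MoE layer's gate logits of shape (batch_size * sequence_length, num_experts).

import torch

num_experts, top_k, router_aux_loss_coef = 4, 2, 0.001                        # assumed config values
router_logits = tuple(torch.randn(8, num_experts) for _ in range(2))          # 2 layers, 8 tokens each
aux_loss = load_balancing_loss_func(router_logits, num_experts, top_k, None)  # no attention mask in this toy case
ce_loss = torch.tensor(2.31)                                                  # stand-in for the shifted cross-entropy
total_loss = ce_loss + router_aux_loss_coef * aux_loss                        # mirrors the combination in forward()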
1574
- def prepare_inputs_for_generation(
1575
- self,
1576
- input_ids,
1577
- past_key_values=None,
1578
- attention_mask=None,
1579
- inputs_embeds=None,
1580
- cache_position=None,
1581
- use_cache=True,
1582
- **kwargs,
1583
- ):
1584
- past_length = 0
1585
-
1586
- # ##### modified: also handle legacy (tuple-format) past_key_values
1587
- if past_key_values is not None:
1588
- if isinstance(past_key_values,Cache):
1589
- # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
1590
- past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1591
- max_cache_length = (
1592
- torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1593
- if past_key_values.get_max_length() is not None
1594
- else None
1595
- )
1596
- cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1597
- else:
1598
- cache_length=past_length=past_key_values[0][0].shape[2]
1599
- max_cache_length=None
1600
- # ##### end of modification; original upstream logic kept below for reference
1601
- # Omit tokens covered by past_key_values
1602
- # if past_key_values is not None:
1603
- # # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
1604
- # past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1605
- # max_cache_length = (
1606
- # torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1607
- # if past_key_values.get_max_length() is not None
1608
- # else None
1609
- # )
1610
- # cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1611
-
1612
- # Keep only the unprocessed tokens:
1613
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1614
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1615
- # input)
1616
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1617
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1618
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1619
- # input_ids based on the past_length.
1620
- elif past_length < input_ids.shape[1]:
1621
- input_ids = input_ids[:, past_length:]
1622
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1623
-
1624
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1625
- if (
1626
- max_cache_length is not None
1627
- and attention_mask is not None
1628
- and cache_length + input_ids.shape[1] > max_cache_length
1629
- ):
1630
- attention_mask = attention_mask[:, -max_cache_length:]
1631
-
1632
- position_ids = kwargs.get("position_ids", None)
1633
- if attention_mask is not None and position_ids is None:
1634
- # create position_ids on the fly for batch generation
1635
- position_ids = attention_mask.long().cumsum(-1) - 1
1636
- position_ids.masked_fill_(attention_mask == 0, 1)
1637
- if past_key_values:
1638
- position_ids = position_ids[:, -input_ids.shape[1] :]
1639
-
1640
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1641
- if inputs_embeds is not None and past_length == 0:
1642
- model_inputs = {"inputs_embeds": inputs_embeds}
1643
- else:
1644
- model_inputs = {"input_ids": input_ids}
1645
-
1646
- input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1647
- if cache_position is None:
1648
- cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1649
- elif use_cache:
1650
- cache_position = cache_position[-input_length:]
1651
-
1652
- model_inputs.update(
1653
- {
1654
- "position_ids": position_ids,
1655
- "past_key_values": past_key_values,
1656
- "use_cache": use_cache,
1657
- "attention_mask": attention_mask,
1658
- "cache_position": cache_position,
1659
- }
1660
- )
1661
- return model_inputs
1662
-
1663
- @staticmethod
1664
- def _reorder_cache(past_key_values, beam_idx):
1665
- reordered_past = ()
1666
- for layer_past in past_key_values:
1667
- reordered_past += (
1668
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1669
- )
1670
- return reordered_past
1671
-
1672
-
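Finally, a minimal end-to-end sketch of loading a converted checkpoint through the auto classes registered in `from_btx` and inspecting the per-layer router logits. The local path and prompt are assumptions (including that a matching tokenizer was saved alongside the model); `trust_remote_code=True` is needed because the model class lives in this custom module.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "./upcycled-qwen2-moe"          # assumed output of the from_btx sketch above
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
)

inputs = tokenizer("Upcycling a dense checkpoint into a sparse MoE", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_router_logits=True)  # MoeCausalLMOutputWithPast
print(out.logits.shape, len(out.router_logits))       # one gate-logit tensor per MoE layer

Calls to generate() would instead go through prepare_inputs_for_generation above, which also accepts the legacy tuple cache format.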