Clean Directory and Update Docs

#1
README.md CHANGED
@@ -1,4 +1,11 @@
1
  ---
 
 
 
 
 
 
 
2
  pipeline_tag: text-generation
3
  ---
4
  <div align="center">
@@ -44,15 +51,15 @@ pipeline_tag: text-generation
44
  <a href="https://github.com/MiniMax-AI/MiniMax-01" target="_blank" style="margin: 2px;">
45
  <img alt="GitHub" src="https://img.shields.io/badge/_GitHub-MinMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
46
  </a>
47
- <a href="https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/LICENSE-MODEL" style="margin: 2px;">
48
  <img alt="Model License" src="https://img.shields.io/badge/_Model_License-Model_Agreement-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
49
  </a>
50
- <a href="https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/LICENSE-CODE" style="margin: 2px;">
51
  <img alt="Code License" src="https://img.shields.io/badge/_Code_License-MIT-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
52
  </a>
53
  </div>
54
  <div align="center" style="line-height: 1;">
55
- <a href="https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/figures/wechat-qrcode.jpeg" target="_blank" style="margin: 2px;">
56
  WeChat
57
  </a>
58
  </div>
@@ -167,7 +174,7 @@ Here we provide a simple example of loading the tokenizer and model to generate
167
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuantoConfig, GenerationConfig
168
 
169
  # load hf config
170
- hf_config = AutoConfig.from_pretrained("MiniMaxAI/MiniMax-Text-01", trust_remote_code=True)
171
 
172
  # quantization config, int8 is recommended
173
  quantization_config = QuantoConfig(
@@ -193,7 +200,7 @@ for i in range(world_size):
193
  device_map[f'model.layers.{i * layers_per_device + j}'] = f'cuda:{i}'
194
 
195
  # load tokenizer
196
- tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01")
197
  prompt = "Hello!"
198
  messages = [
199
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model."}]},
@@ -209,11 +216,10 @@ model_inputs = tokenizer(text, return_tensors="pt").to("cuda")
209
 
210
  # load bfloat16 model, move to device, and apply quantization
211
  quantized_model = AutoModelForCausalLM.from_pretrained(
212
- "MiniMaxAI/MiniMax-Text-01",
213
  torch_dtype="bfloat16",
214
  device_map=device_map,
215
  quantization_config=quantization_config,
216
- trust_remote_code=True,
217
  offload_buffers=True,
218
  )
219
 
 
1
  ---
2
+ library_name: transformers
3
+ license: other
4
+ license_name: minimax
5
+ license_link: LICENSE
6
+ tags:
7
+ - moe
8
+ - arxiv:2501.08313
9
  pipeline_tag: text-generation
10
  ---
11
  <div align="center">
 
51
  <a href="https://github.com/MiniMax-AI/MiniMax-01" target="_blank" style="margin: 2px;">
52
  <img alt="GitHub" src="https://img.shields.io/badge/_GitHub-MinMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
53
  </a>
54
+ <a href="https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf/blob/main/LICENSE-MODEL" style="margin: 2px;">
55
  <img alt="Model License" src="https://img.shields.io/badge/_Model_License-Model_Agreement-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
56
  </a>
57
+ <a href="https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf/blob/main/LICENSE-CODE" style="margin: 2px;">
58
  <img alt="Code License" src="https://img.shields.io/badge/_Code_License-MIT-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
59
  </a>
60
  </div>
61
  <div align="center" style="line-height: 1;">
62
+ <a href="https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf/blob/main/figures/wechat-qrcode.jpeg" target="_blank" style="margin: 2px;">
63
  WeChat
64
  </a>
65
  </div>
 
174
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuantoConfig, GenerationConfig
175
 
176
  # load hf config
177
+ hf_config = AutoConfig.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
178
 
179
  # quantization config, int8 is recommended
180
  quantization_config = QuantoConfig(
 
200
  device_map[f'model.layers.{i * layers_per_device + j}'] = f'cuda:{i}'
201
 
202
  # load tokenizer
203
+ tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
204
  prompt = "Hello!"
205
  messages = [
206
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model."}]},
 
216
 
217
  # load bfloat16 model, move to device, and apply quantization
218
  quantized_model = AutoModelForCausalLM.from_pretrained(
219
+ "MiniMaxAI/MiniMax-Text-01-hf",
220
  torch_dtype="bfloat16",
221
  device_map=device_map,
222
  quantization_config=quantization_config,
 
223
  offload_buffers=True,
224
  )
225
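The README fragments above show the config, quantization, tokenizer, and model loading as separate snippets. For convenience, here is one possible end-to-end sketch against the renamed `MiniMaxAI/MiniMax-Text-01-hf` repo; the device-map layout, the `weights="int8"` choice, and the generation/decoding steps are illustrative assumptions, not lines taken from the diff.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig

MODEL_ID = "MiniMaxAI/MiniMax-Text-01-hf"

# load hf config (no trust_remote_code needed once the custom modeling files are removed)
hf_config = AutoConfig.from_pretrained(MODEL_ID)

# quantization config, int8 is recommended
quantization_config = QuantoConfig(weights="int8")

# spread decoder layers across the visible GPUs (illustrative; the README builds a similar
# map from world_size / layers_per_device). Assumes at least one CUDA device.
world_size = torch.cuda.device_count()
layers_per_device = (hf_config.num_hidden_layers + world_size - 1) // world_size
device_map = {
    "model.embed_tokens": "cuda:0",
    "model.norm": f"cuda:{world_size - 1}",
    "lm_head": f"cuda:{world_size - 1}",
}
for i in range(hf_config.num_hidden_layers):
    device_map[f"model.layers.{i}"] = f"cuda:{i // layers_per_device}"

# tokenize a chat-formatted prompt
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model."}]},
    {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer(text, return_tensors="pt").to("cuda")

# load the bfloat16 model, shard it with the device map, and quantize on the fly
quantized_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="bfloat16",
    device_map=device_map,
    quantization_config=quantization_config,
    offload_buffers=True,
)
generated_ids = quantized_model.generate(**model_inputs, max_new_tokens=128)
print(tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)[0])
```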
 
configuration_minimax_text_01.py DELETED
@@ -1,152 +0,0 @@
1
- """ MiniMaxText01 model configuration"""
2
-
3
- from transformers.configuration_utils import PretrainedConfig
4
- from transformers.utils import logging
5
-
6
-
7
- logger = logging.get_logger(__name__)
8
-
9
-
10
- class MiniMaxText01Config(PretrainedConfig):
11
- r"""
12
- This is the configuration class to store the configuration of a [`MiniMaxText01Model`]. It is used to instantiate a
13
- MiniMaxText01 model according to the specified arguments, defining the model architecture. Instantiating a configuration
14
- with the defaults will yield a similar configuration to that of the MiniMaxText01.
15
-
16
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
17
- documentation from [`PretrainedConfig`] for more information.
18
-
19
-
20
- Args:
21
- vocab_size (`int`, *optional*, defaults to 32000):
22
- Vocabulary size of the MiniMaxText01 model. Defines the number of different tokens that can be represented by the
23
- `input_ids` passed when calling [`MiniMaxText01Model`]
24
- hidden_size (`int`, *optional*, defaults to 4096):
25
- Dimension of the hidden representations.
26
- intermediate_size (`int`, *optional*, defaults to 14336):
27
- Dimension of the MLP representations.
28
- num_hidden_layers (`int`, *optional*, defaults to 32):
29
- Number of hidden layers in the Transformer encoder.
30
- num_attention_heads (`int`, *optional*, defaults to 32):
31
- Number of attention heads for each attention layer in the Transformer encoder.
32
- num_key_value_heads (`int`, *optional*, defaults to 8):
33
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
34
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
35
- `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
36
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
37
- by meanpooling all the original heads within that group. For more details checkout [this
38
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
39
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
40
- The non-linear activation function (function or string) in the decoder.
41
- max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
42
- The maximum sequence length that this model might ever be used with. MiniMaxText01's sliding window attention
43
- allows sequences of up to 4096*32 tokens.
44
- initializer_range (`float`, *optional*, defaults to 0.02):
45
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
46
- rms_norm_eps (`float`, *optional*, defaults to 1e-05):
47
- The epsilon used by the rms normalization layers.
48
- use_cache (`bool`, *optional*, defaults to `True`):
49
- Whether or not the model should return the last key/values attentions (not used by all models). Only
50
- relevant if `config.is_decoder=True`.
51
- pad_token_id (`int`, *optional*):
52
- The id of the padding token.
53
- bos_token_id (`int`, *optional*, defaults to 1):
54
- The id of the "beginning-of-sequence" token.
55
- eos_token_id (`int`, *optional*, defaults to 2):
56
- The id of the "end-of-sequence" token.
57
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
58
- Whether the model's input and output word embeddings should be tied.
59
- rope_theta (`float`, *optional*, defaults to 1000000.0):
60
- The base period of the RoPE embeddings.
61
- sliding_window (`int`, *optional*):
62
- Sliding window attention window size. If not specified, will default to `4096`.
63
- attention_dropout (`float`, *optional*, defaults to 0.0):
64
- The dropout ratio for the attention probabilities.
65
- num_experts_per_tok (`int`, *optional*, defaults to 2):
66
- The number of experts to route per token; this can also be interpreted as the `top-k` routing
67
- parameter
68
- num_local_experts (`int`, *optional*, defaults to 8):
69
- Number of experts per Sparse MLP layer.
70
- output_router_logits (`bool`, *optional*, defaults to `False`):
71
- Whether or not the router logits should be returned by the model. Enabling this will also
72
- allow the model to output the auxiliary loss. See [here]() for more details
73
- router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
74
- The aux loss factor for the total loss.
75
- router_jitter_noise (`float`, *optional*, defaults to 0.0):
76
- Amount of noise to add to the router.
77
-
78
- ```python
79
- >>> from transformers import MiniMaxText01Model, MiniMaxText01Config
80
-
81
- >>> # Initializing a MiniMaxText01 style configuration
82
- >>> configuration = MiniMaxText01Config()
83
-
84
- >>> # Initializing a model from the MiniMaxText01 style configuration
85
- >>> model = MiniMaxText01Model(configuration)
86
-
87
- >>> # Accessing the model configuration
88
- >>> configuration = model.config
89
- ```"""
90
-
91
- model_type = "MiniMaxText01"
92
- keys_to_ignore_at_inference = ["past_key_values"]
93
-
94
- def __init__(
95
- self,
96
- vocab_size=32000,
97
- hidden_size=4096,
98
- intermediate_size=14336,
99
- num_hidden_layers=32,
100
- num_attention_heads=32,
101
- num_key_value_heads=8,
102
- hidden_act="silu",
103
- max_position_embeddings=4096 * 32,
104
- initializer_range=0.02,
105
- rms_norm_eps=1e-5,
106
- use_cache=True,
107
- pad_token_id=None,
108
- bos_token_id=None,
109
- eos_token_id=None,
110
- tie_word_embeddings=False,
111
- rope_theta=1e6,
112
- sliding_window=None,
113
- attention_dropout=0.0,
114
- num_experts_per_tok=2,
115
- num_local_experts=8,
116
- output_router_logits=False,
117
- router_aux_loss_coef=0.001,
118
- router_jitter_noise=0.0,
119
- **kwargs,
120
- ):
121
- self.vocab_size = vocab_size
122
- self.max_position_embeddings = max_position_embeddings
123
- self.hidden_size = hidden_size
124
- self.intermediate_size = intermediate_size
125
- self.num_hidden_layers = num_hidden_layers
126
- self.num_attention_heads = num_attention_heads
127
- self.sliding_window = sliding_window
128
-
129
- # for backward compatibility
130
- if num_key_value_heads is None:
131
- num_key_value_heads = num_attention_heads
132
-
133
- self.num_key_value_heads = num_key_value_heads
134
- self.hidden_act = hidden_act
135
- self.initializer_range = initializer_range
136
- self.rms_norm_eps = rms_norm_eps
137
- self.use_cache = use_cache
138
- self.rope_theta = rope_theta
139
- self.attention_dropout = attention_dropout
140
-
141
- self.num_experts_per_tok = num_experts_per_tok
142
- self.num_local_experts = num_local_experts
143
- self.output_router_logits = output_router_logits
144
- self.router_aux_loss_coef = router_aux_loss_coef
145
- self.router_jitter_noise = router_jitter_noise
146
- super().__init__(
147
- pad_token_id=pad_token_id,
148
- bos_token_id=bos_token_id,
149
- eos_token_id=eos_token_id,
150
- tie_word_embeddings=tie_word_embeddings,
151
- **kwargs,
152
- )
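With the custom configuration class below removed, the fields documented in its docstring (GQA via `num_key_value_heads`, the MoE routing arguments) are read from the hub config through the standard Transformers classes. A minimal sketch, assuming these attribute names carry over unchanged to the `MiniMax-Text-01-hf` config:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
# MHA if the two are equal, MQA if num_key_value_heads == 1, otherwise GQA
print(config.num_attention_heads, config.num_key_value_heads)
# experts per sparse MLP layer and how many are routed per token (top-k)
print(config.num_local_experts, config.num_experts_per_tok)
```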
 
modeling_minimax_text_01.py DELETED
@@ -1,1701 +0,0 @@
1
- """ PyTorch MiniMaxText01 model."""
2
- import inspect
3
- import math
4
- import warnings
5
- from typing import List, Optional, Tuple, Union
6
- import os
7
- import copy
8
- import torch
9
- import torch.nn.functional as F
10
- import torch.utils.checkpoint
11
- from torch import nn
12
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
- from einops import rearrange, repeat
14
- from transformers.activations import ACT2FN
15
- from transformers.cache_utils import Cache, DynamicCache
16
- from transformers.modeling_attn_mask_utils import (
17
- _prepare_4d_causal_attention_mask,
18
- )
19
- from transformers.modeling_outputs import (
20
- MoeCausalLMOutputWithPast,
21
- MoeModelOutputWithPast,
22
- SequenceClassifierOutputWithPast,
23
- )
24
- from transformers.modeling_utils import PreTrainedModel
25
- from transformers.utils import (
26
- add_start_docstrings,
27
- add_start_docstrings_to_model_forward,
28
- is_flash_attn_2_available,
29
- is_flash_attn_greater_or_equal_2_10,
30
- logging,
31
- replace_return_docstrings,
32
- )
33
- from transformers.utils.import_utils import is_torch_fx_available
34
- from .configuration_minimax_text_01 import MiniMaxText01Config
35
-
36
- if is_flash_attn_2_available():
37
- from flash_attn import flash_attn_func, flash_attn_varlen_func
38
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
39
-
40
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
41
-
42
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
43
- # It means that the function will not be traced through and simply appear as a node in the graph.
44
- if is_torch_fx_available():
45
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
46
-
47
- use_triton = eval(os.environ.get("use_triton", default="False"))
48
- debug = eval(os.environ.get("debug", default="False"))
49
- do_eval = eval(os.environ.get("do_eval", default="False"))
50
- eval_and_not_generate = eval(os.environ.get("eval_and_not_generate", default="False"))
51
- BLOCK = 256
52
-
53
- logger = logging.get_logger(__name__)
54
-
55
- _CONFIG_FOR_DOC = "MiniMaxText01Config"
56
-
57
-
58
- def get_activation_fn(activation):
59
- if debug:
60
- logger.info(f"activation: {activation}")
61
- if activation == "gelu":
62
- return F.gelu
63
- elif activation == "relu":
64
- return F.relu
65
- elif activation == "elu":
66
- return F.elu
67
- elif activation == "sigmoid":
68
- return F.sigmoid
69
- elif activation == "exp":
70
-
71
- def f(x):
72
- with torch.no_grad():
73
- x_max = torch.max(x, dim=-1, keepdims=True).values
74
- y = torch.exp(x - x_max)
75
-
76
- return y
77
-
78
- return f
79
- elif activation == "leak":
80
- return F.leaky_relu
81
- elif activation == "1+elu":
82
-
83
- def f(x):
84
- return 1 + F.elu(x)
85
-
86
- return f
87
- elif activation == "2+elu":
88
-
89
- def f(x):
90
- return 2 + F.elu(x)
91
-
92
- return f
93
- elif activation == "silu" or activation == "swish":
94
- return F.silu
95
- elif activation == "sine":
96
- return torch.sin
97
- else:
98
- logger.info(
99
- f"activation: does not support {activation}, use Identity!!!")
100
- return lambda x: x
101
-
102
-
103
- def load_balancing_loss_func(
104
- gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2,
105
- attention_mask: Optional[torch.Tensor] = None
106
- ) -> float:
107
- r"""
108
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
109
-
110
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
111
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
112
- experts is too unbalanced.
113
-
114
- Args:
115
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
116
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
117
- shape [batch_size X sequence_length, num_experts].
118
- attention_mask (`torch.Tensor`, None):
119
- The attention_mask used in forward function
120
- shape [batch_size X sequence_length] if not None.
121
- num_experts (`int`, *optional*):
122
- Number of experts
123
-
124
- Returns:
125
- The auxiliary loss.
126
- """
127
- if gate_logits is None or not isinstance(gate_logits, tuple):
128
- return 0
129
-
130
- if isinstance(gate_logits, tuple):
131
- compute_device = gate_logits[0].device
132
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
133
-
134
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
135
-
136
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
137
-
138
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
139
-
140
- if attention_mask is None:
141
- # Compute the percentage of tokens routed to each experts
142
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
143
-
144
- # Compute the average probability of routing to these experts
145
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
146
- else:
147
- batch_size, sequence_length = attention_mask.shape
148
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
149
-
150
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
151
- expert_attention_mask = (
152
- attention_mask[None, :, :, None, None]
153
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
154
- .reshape(-1, top_k, num_experts)
155
- .to(compute_device)
156
- )
157
-
158
- # Compute the percentage of tokens routed to each experts
159
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
160
- expert_attention_mask, dim=0
161
- )
162
-
163
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
164
- router_per_expert_attention_mask = (
165
- attention_mask[None, :, :, None]
166
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
167
- .reshape(-1, num_experts)
168
- .to(compute_device)
169
- )
170
-
171
- # Compute the average probability of routing to these experts
172
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
173
- router_per_expert_attention_mask, dim=0
174
- )
175
-
176
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
177
- return overall_loss * num_experts
178
-
179
-
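In the unmasked case, the loss computed above is the Switch Transformers auxiliary loss: with $N$ experts, $f_i$ the fraction of top-$k$ routing assignments that select expert $i$ (`tokens_per_expert`), and $P_i$ the mean router probability for expert $i$ (`router_prob_per_expert`),

$$\mathcal{L}_{\text{aux}} = N \sum_{i=1}^{N} f_i \, P_i .$$

The masked branch computes the same quantity while excluding padding positions.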
180
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
181
- def _get_unpad_data(attention_mask):
182
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
183
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
184
- max_seqlen_in_batch = seqlens_in_batch.max().item()
185
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
186
- return (
187
- indices,
188
- cu_seqlens,
189
- max_seqlen_in_batch,
190
- )
191
-
192
-
193
- class GLU(nn.Module):
194
-
195
- def __init__(self, d1, d2, bias=False):
196
- super().__init__()
197
-
198
- self.l1 = nn.Linear(d1, d2, bias=bias)
199
- self.l2 = nn.Linear(d1, d2, bias=bias)
200
- self.l3 = nn.Linear(d2, d1, bias=bias)
201
-
202
- def forward(self, x):
203
- o1 = self.l1(x)
204
- o2 = self.l2(x)
205
- output = o1 * o2
206
- output = self.l3(output)
207
- return output
208
-
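Written out, the gating above is

$$\operatorname{GLU}(x) = W_3\big((W_1 x) \odot (W_2 x)\big),$$

with $W_1$, $W_2$, $W_3$ the `l1`, `l2`, `l3` linear layers.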
209
-
210
- class MiniMaxText01LightningAttention(nn.Module):
211
- def __init__(self, config: MiniMaxText01Config, layer_idx: Optional[int] = None):
212
- super().__init__()
213
- bias = False
214
- self.hidden_size = config.hidden_size
215
- self.num_heads = config.num_attention_heads
216
- self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
217
-
218
- self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
219
- self.act = get_activation_fn(config.hidden_act)
220
- self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
221
-
222
- self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
223
- self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
224
-
225
- # for inference only
226
- self.offset = 0
227
- self.layer_idx = layer_idx
228
-
229
- def forward(
230
- self,
231
- hidden_states,
232
- attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
233
- output_attentions: bool = False,
234
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
235
- use_cache: bool = False,
236
- slope_rate: Optional[torch.Tensor] = None,
237
- **kwargs
238
- ):
239
- if (not self.training) and (not do_eval):
240
- return self.inference(
241
- hidden_states,
242
- attn_mask,
243
- output_attentions,
244
- past_key_value,
245
- use_cache,
246
- slope_rate,
247
- )
248
-
249
- def inference(
250
- self,
251
- x,
252
- attn_mask: Optional[torch.Tensor] = None, # (b, n)
253
- output_attentions: bool = False,
254
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
255
- use_cache: bool = False,
256
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
257
- ):
258
- # x: b n d
259
- b, n, d = x.shape
260
- # linear map
261
- qkv = self.act(self.qkv_proj(x))
262
- new_shape = qkv.size()[:-1] + (self.num_heads, -1)
263
- qkv = qkv.view(*new_shape)
264
- q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
265
- q = q.transpose(1, 2)
266
- k = k.transpose(1, 2)
267
- v = v.transpose(1, 2)
268
-
269
- if past_key_value is None:
270
- self.offset = q.shape[-2]
271
- else:
272
- self.offset += 1
273
-
274
- # for align with metaseq
275
- ratio = torch.exp(-slope_rate)
276
-
277
- # only use for the first time
278
- if past_key_value is None:
279
- slope_rate = slope_rate.to(torch.float32)
280
- if attn_mask is not None:
281
- v = v.masked_fill((1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0)
282
- NUM_BLOCK = (n + BLOCK - 1) // BLOCK
283
- b, h, n, d = q.shape
284
- e = v.shape[-1]
285
- # other
286
- array = torch.arange(BLOCK).to(q) + 1
287
- q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
288
- k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
289
- index = array[:, None] - array[None, :]
290
- s_index = slope_rate * index[
291
- None,
292
- None,
293
- ]
294
- s_index = torch.where(index >= 0, -s_index, float("-inf"))
295
- diag_decay = torch.exp(s_index)
296
-
297
- kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
298
- output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
299
- for i in range(NUM_BLOCK):
300
- si = i * BLOCK
301
- ei = min(si + BLOCK, n)
302
- m = ei - si
303
- qi = q[:, :, si:ei].contiguous()
304
- ki = k[:, :, si:ei].contiguous()
305
- vi = v[:, :, si:ei].contiguous()
306
- qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
307
-
308
- # diag
309
- qk = torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) * diag_decay[:, :, :m, :m]
310
- qkv_diag = torch.matmul(qk, vi.to(torch.float32))
311
- block_decay = torch.exp(-slope_rate * m)
312
- output[:, :, si:ei] = qkv_none_diag + qkv_diag
313
- kv = block_decay * kv + torch.matmul((ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi)
314
-
315
- else:
316
- kv = past_key_value
317
- output = []
318
- for i in range(n):
319
- kv = ratio * kv + torch.einsum(
320
- "... n d, ... n e -> ... d e",
321
- k[:, :, i:i + 1],
322
- v[:, :, i:i + 1],
323
- )
324
- qkv = torch.einsum("... n e, ... e d -> ... n d", q[:, :, i:i + 1], kv.to(q.dtype))
325
- output.append(qkv)
326
- output = torch.concat(output, dim=-2)
327
- # reshape
328
- output = rearrange(output, "b h n d -> b n (h d)")
329
- # normalize
330
- output = self.norm(output)
331
- # gate
332
- output = F.sigmoid(self.output_gate(x)) * output
333
- # outproj
334
- output = self.out_proj(output)
335
-
336
- attn_weights = None
337
-
338
- return output, attn_weights, kv
339
-
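A compact way to read the two branches of `inference` above: each head keeps a single running state $KV$ with decay rate $s$ (the `slope_rate`). In the decode branch,

$$KV_t = e^{-s}\, KV_{t-1} + k_t^{\top} v_t, \qquad o_t = q_t\, KV_t,$$

while the prefill branch computes the same recurrence block-wise (block size 256), splitting each block into a diagonal term weighted by $e^{-s(i-j)}$ for $i \ge j$ and an off-diagonal term carried through the running $KV$ state. The result is then RMS-normalized, gated by the sigmoid of `output_gate(x)`, and projected by `out_proj`.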
340
-
341
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
342
- class MiniMaxText01RMSNorm(nn.Module):
343
- def __init__(self, hidden_size, eps=1e-6):
344
- """
345
- MiniMaxText01RMSNorm is equivalent to T5LayerNorm
346
- """
347
- super().__init__()
348
- self.weight = nn.Parameter(torch.ones(hidden_size))
349
- self.variance_epsilon = eps
350
-
351
- def forward(self, hidden_states):
352
- input_dtype = hidden_states.dtype
353
- hidden_states = hidden_states.to(torch.float32)
354
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
355
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
356
- return self.weight * hidden_states.to(input_dtype)
357
-
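For reference, the normalization above computes (in float32, then cast back to the input dtype)

$$\operatorname{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^{2} + \varepsilon}}.$$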
358
-
359
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxText01
360
- class MiniMaxText01RotaryEmbedding(nn.Module):
361
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
362
- super().__init__()
363
-
364
- self.dim = dim
365
- self.max_position_embeddings = max_position_embeddings
366
- self.base = base
367
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
368
- self.register_buffer("inv_freq", inv_freq, persistent=False)
369
-
370
- # Build here to make `torch.jit.trace` work.
371
- self._set_cos_sin_cache(
372
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
373
- )
374
-
375
- def _set_cos_sin_cache(self, seq_len, device, dtype):
376
- self.max_seq_len_cached = seq_len
377
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
378
-
379
- freqs = torch.outer(t, self.inv_freq)
380
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
381
- emb = torch.cat((freqs, freqs), dim=-1)
382
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
383
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
384
-
385
- def forward(self, x, seq_len=None):
386
- # x: [bs, num_attention_heads, seq_len, head_size]
387
- if seq_len > self.max_seq_len_cached:
388
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
389
-
390
- return (
391
- self.cos_cached[:seq_len].to(dtype=torch.float32),
392
- self.sin_cached[:seq_len].to(dtype=torch.float32),
393
- )
394
-
395
-
396
- # Copied from transformers.models.llama.modeling_llama.rotate_half
397
- def rotate_half(x):
398
- """Rotates half the hidden dims of the input."""
399
- x1 = x[..., : x.shape[-1] // 2]
400
- x2 = x[..., x.shape[-1] // 2:]
401
- return torch.cat((-x2, x1), dim=-1)
402
-
403
-
404
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
405
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
406
- """Applies Rotary Position Embedding to the query and key tensors.
407
-
408
- Args:
409
- q (`torch.Tensor`): The query tensor.
410
- k (`torch.Tensor`): The key tensor.
411
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
412
- sin (`torch.Tensor`): The sine part of the rotary embedding.
413
- position_ids (`torch.Tensor`):
414
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
415
- used to pass offsetted position ids when working with a KV-cache.
416
- unsqueeze_dim (`int`, *optional*, defaults to 1):
417
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
418
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
419
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
420
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
421
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
422
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
423
- Returns:
424
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
425
- """
426
- dtype = q.dtype
427
- rot_dim = cos.shape[-1]
428
- q_, q_pass = q[..., :rot_dim], q[..., rot_dim:]
429
- k_, k_pass = k[..., :rot_dim], k[..., rot_dim:]
430
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
431
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
432
- q_embed = (q_ * cos) + (rotate_half(q_) * sin)
433
- k_embed = (k_ * cos) + (rotate_half(k_) * sin)
434
- return torch.cat((q_embed, q_pass), dim=-1).to(dtype), torch.cat((k_embed, k_pass), dim=-1).to(dtype)
435
-
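Note that the function above applies the rotation only to the first `rot_dim` channels and passes the remaining channels through unchanged, i.e. with $q = [q_{\text{rot}}, q_{\text{pass}}]$:

$$q' = \big[\, q_{\text{rot}} \odot \cos\theta \;+\; \operatorname{rotate\_half}(q_{\text{rot}}) \odot \sin\theta,\;\; q_{\text{pass}} \,\big],$$

and likewise for $k$.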
436
-
437
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
438
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
439
- """
440
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
441
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
442
- """
443
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
444
- if n_rep == 1:
445
- return hidden_states
446
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
447
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
448
-
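A quick sanity check of the equivalence stated in the docstring (shapes are arbitrary and chosen only for illustration):

```python
import torch

kv = torch.randn(2, 4, 7, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3                       # num_attention_heads // num_key_value_heads
expanded = kv[:, :, None, :, :].expand(2, 4, n_rep, 7, 64).reshape(2, 4 * n_rep, 7, 64)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
```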
449
-
450
- # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxText01
451
- class MiniMaxText01Attention(nn.Module):
452
- """
453
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
454
- and "Generating Long Sequences with Sparse Transformers".
455
- """
456
-
457
- def __init__(self, config: MiniMaxText01Config, layer_idx: Optional[int] = None):
458
- super().__init__()
459
- self.config = config
460
- self.layer_idx = layer_idx
461
- if layer_idx is None:
462
- logger.warning_once(
463
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
464
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
465
- "when creating this class."
466
- )
467
-
468
- self.hidden_size = config.hidden_size
469
- self.num_heads = config.num_attention_heads
470
- self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
471
- self.num_key_value_heads = config.num_key_value_heads
472
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
473
- self.max_position_embeddings = config.max_position_embeddings
474
- self.rope_theta = config.rope_theta
475
- self.is_causal = True
476
- self.attention_dropout = config.attention_dropout
477
-
478
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
479
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
480
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
481
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
482
- self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)
483
-
484
- self.rotary_emb = MiniMaxText01RotaryEmbedding(
485
- self.rotary_dim,
486
- max_position_embeddings=self.max_position_embeddings,
487
- base=self.rope_theta,
488
- )
489
-
490
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
491
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
492
-
493
- def forward(
494
- self,
495
- hidden_states: torch.Tensor,
496
- attention_mask: Optional[torch.Tensor] = None,
497
- position_ids: Optional[torch.LongTensor] = None,
498
- past_key_value: Optional[Cache] = None,
499
- output_attentions: bool = False,
500
- use_cache: bool = False,
501
- **kwargs,
502
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
503
- if "padding_mask" in kwargs:
504
- warnings.warn(
505
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
506
- )
507
- bsz, q_len, _ = hidden_states.size()
508
-
509
- query_states = self.q_proj(hidden_states)
510
- key_states = self.k_proj(hidden_states)
511
- value_states = self.v_proj(hidden_states)
512
-
513
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
514
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
515
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
516
-
517
- kv_seq_len = key_states.shape[-2]
518
- if past_key_value is not None:
519
- if self.layer_idx is None:
520
- raise ValueError(
521
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
522
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
523
- "with a layer index."
524
- )
525
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
526
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
527
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
528
-
529
- if past_key_value is not None:
530
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
531
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
532
-
533
- # repeat k/v heads if n_kv_heads < n_heads
534
- key_states = repeat_kv(key_states, self.num_key_value_groups)
535
- value_states = repeat_kv(value_states, self.num_key_value_groups)
536
-
537
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
538
-
539
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
540
- raise ValueError(
541
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
542
- f" {attn_weights.size()}"
543
- )
544
-
545
- if attention_mask is not None:
546
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
547
- raise ValueError(
548
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
549
- )
550
-
551
- attn_weights = attn_weights + attention_mask
552
-
553
- # upcast attention to fp32
554
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
555
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
556
- attn_output = torch.matmul(attn_weights, value_states)
557
-
558
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
559
- raise ValueError(
560
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
561
- f" {attn_output.size()}"
562
- )
563
-
564
- attn_output = attn_output.transpose(1, 2).contiguous()
565
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
566
-
567
- attn_output = self.o_proj(attn_output)
568
-
569
- if not output_attentions:
570
- attn_weights = None
571
-
572
- return attn_output, attn_weights, past_key_value
573
-
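Per head, the eager path above is standard scaled dot-product attention, with K and V first expanded from `num_key_value_heads` to `num_attention_heads` by `repeat_kv`:

$$\operatorname{Attn}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}} + M\right) V,$$

where $M$ is the additive causal/padding mask.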
574
-
575
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxText01
576
- class MiniMaxText01FlashAttention2(MiniMaxText01Attention):
577
- """
578
- MiniMaxText01 flash attention module. This module inherits from `MiniMaxText01Attention` as the weights of the module stays
579
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
580
- flash attention and deal with padding tokens in case the input contains any of them.
581
- """
582
-
583
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
584
- def __init__(self, *args, **kwargs):
585
- super().__init__(*args, **kwargs)
586
-
587
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
588
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
589
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
590
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
591
-
592
- def forward(
593
- self,
594
- hidden_states: torch.Tensor,
595
- attention_mask: Optional[torch.Tensor] = None,
596
- position_ids: Optional[torch.LongTensor] = None,
597
- past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
598
- output_attentions: bool = False,
599
- use_cache: bool = False,
600
- **kwargs,
601
- ):
602
- if "padding_mask" in kwargs:
603
- warnings.warn(
604
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
605
- )
606
-
607
- # overwrite attention_mask with padding_mask
608
- attention_mask = kwargs.pop("padding_mask")
609
- bsz, q_len, _ = hidden_states.size()
610
-
611
- query_states = self.q_proj(hidden_states)
612
- key_states = self.k_proj(hidden_states)
613
- value_states = self.v_proj(hidden_states)
614
-
615
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
616
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
617
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
618
-
619
- kv_seq_len = key_states.shape[-2]
620
- if past_key_value is not None:
621
- kv_seq_len += past_key_value[0].shape[-3]
622
-
623
- # Because the input can be padded, the absolute sequence length depends on the max position id.
624
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
625
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
626
-
627
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
628
-
629
- use_sliding_windows = (
630
- _flash_supports_window_size
631
- and getattr(self.config, "sliding_window", None) is not None
632
- and kv_seq_len > self.config.sliding_window
633
- )
634
-
635
- if not _flash_supports_window_size:
636
- logger.warning_once(
637
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
638
- " make sure to upgrade flash-attn library."
639
- )
640
-
641
- dropout_rate = 0.0 if not self.training else self.attention_dropout
642
-
643
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
644
- # therefore the input hidden states gets silently casted in float32. Hence, we need
645
- # cast them back in float16 just to be sure everything works as expected.
646
- input_dtype = query_states.dtype
647
- if input_dtype == torch.float32:
648
- if torch.is_autocast_enabled():
649
- target_dtype = torch.get_autocast_gpu_dtype()
650
- # Handle the case where the model is quantized
651
- elif hasattr(self.config, "_pre_quantization_dtype"):
652
- target_dtype = self.config._pre_quantization_dtype
653
- else:
654
- target_dtype = self.q_proj.weight.dtype
655
-
656
- logger.warning_once(
657
- f"The input hidden states seems to be silently casted in float32, this might be related to"
658
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
659
- f" {target_dtype}."
660
- )
661
-
662
- query_states = query_states.to(target_dtype)
663
- key_states = key_states.to(target_dtype)
664
- value_states = value_states.to(target_dtype)
665
-
666
- # Reshape to the expected shape for Flash Attention
667
- query_states = query_states.transpose(1, 2)
668
- key_states = key_states.transpose(1, 2)
669
- value_states = value_states.transpose(1, 2)
670
-
671
- if past_key_value is not None:
672
- # reuse k, v, for evaluation only
673
- key_states = torch.cat([past_key_value[0], key_states], dim=-3)
674
- value_states = torch.cat([past_key_value[1], value_states], dim=-3)
675
-
676
- past_key_value = (key_states, value_states) if use_cache else None
677
-
678
- attn_output = self._flash_attention_forward(
679
- query_states,
680
- key_states,
681
- value_states,
682
- attention_mask,
683
- q_len,
684
- dropout=dropout_rate,
685
- use_sliding_windows=use_sliding_windows,
686
- )
687
-
688
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
689
- attn_output = self.o_proj(attn_output)
690
-
691
- if not output_attentions:
692
- attn_weights = None
693
-
694
- return attn_output, attn_weights, past_key_value
695
-
696
- def _flash_attention_forward(
697
- self,
698
- query_states,
699
- key_states,
700
- value_states,
701
- attention_mask,
702
- query_length,
703
- dropout=0.0,
704
- softmax_scale=None,
705
- use_sliding_windows=False,
706
- ):
707
- """
708
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
709
- first unpad the input, then computes the attention scores and pad the final attention scores.
710
-
711
- Args:
712
- query_states (`torch.Tensor`):
713
- Input query states to be passed to Flash Attention API
714
- key_states (`torch.Tensor`):
715
- Input key states to be passed to Flash Attention API
716
- value_states (`torch.Tensor`):
717
- Input value states to be passed to Flash Attention API
718
- attention_mask (`torch.Tensor`):
719
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
720
- position of padding tokens and 1 for the position of non-padding tokens.
721
- dropout (`float`):
722
- Attention dropout
723
- softmax_scale (`float`, *optional*):
724
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
725
- use_sliding_windows (`bool`, *optional*):
726
- Whether to activate sliding window attention.
727
- """
728
- if not self._flash_attn_uses_top_left_mask:
729
- causal = self.is_causal
730
- else:
731
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
732
- causal = self.is_causal and query_length != 1
733
-
734
- # Contains at least one padding token in the sequence
735
- if attention_mask is not None:
736
- batch_size = query_states.shape[0]
737
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
738
- query_states, key_states, value_states, attention_mask, query_length
739
- )
740
-
741
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
742
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
743
-
744
- if not use_sliding_windows:
745
- attn_output_unpad = flash_attn_varlen_func(
746
- query_states,
747
- key_states,
748
- value_states,
749
- cu_seqlens_q=cu_seqlens_q,
750
- cu_seqlens_k=cu_seqlens_k,
751
- max_seqlen_q=max_seqlen_in_batch_q,
752
- max_seqlen_k=max_seqlen_in_batch_k,
753
- dropout_p=dropout,
754
- softmax_scale=softmax_scale,
755
- causal=causal,
756
- )
757
- else:
758
- attn_output_unpad = flash_attn_varlen_func(
759
- query_states,
760
- key_states,
761
- value_states,
762
- cu_seqlens_q=cu_seqlens_q,
763
- cu_seqlens_k=cu_seqlens_k,
764
- max_seqlen_q=max_seqlen_in_batch_q,
765
- max_seqlen_k=max_seqlen_in_batch_k,
766
- dropout_p=dropout,
767
- softmax_scale=softmax_scale,
768
- causal=causal,
769
- window_size=(self.config.sliding_window, self.config.sliding_window),
770
- )
771
-
772
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
773
- else:
774
- if not use_sliding_windows:
775
- attn_output = flash_attn_func(
776
- query_states,
777
- key_states,
778
- value_states,
779
- dropout,
780
- softmax_scale=softmax_scale,
781
- causal=causal,
782
- )
783
- else:
784
- attn_output = flash_attn_func(
785
- query_states,
786
- key_states,
787
- value_states,
788
- dropout,
789
- softmax_scale=softmax_scale,
790
- causal=causal,
791
- window_size=(self.config.sliding_window, self.config.sliding_window),
792
- )
793
-
794
- return attn_output
795
-
796
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
797
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
798
-
799
- # On the first iteration we need to properly re-create the padding mask
800
- # by slicing it on the proper place
801
- if kv_seq_len != attention_mask.shape[-1]:
802
- attention_mask_num_tokens = attention_mask.shape[-1]
803
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len:]
804
-
805
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
806
-
807
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
808
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
809
-
810
- if query_length == kv_seq_len:
811
- query_layer = index_first_axis(
812
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
813
- )
814
- cu_seqlens_q = cu_seqlens_k
815
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
816
- indices_q = indices_k
817
- elif query_length == 1:
818
- max_seqlen_in_batch_q = 1
819
- cu_seqlens_q = torch.arange(
820
- batch_size + 1, dtype=torch.int32, device=query_layer.device
821
- ) # There is a memcpy here, that is very bad.
822
- indices_q = cu_seqlens_q[:-1]
823
- query_layer = query_layer.squeeze(1)
824
- else:
825
- # The -q_len: slice assumes left padding.
826
- attention_mask = attention_mask[:, -query_length:]
827
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
828
-
829
- return (
830
- query_layer,
831
- key_layer,
832
- value_layer,
833
- indices_q,
834
- (cu_seqlens_q, cu_seqlens_k),
835
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
836
- )
837
-
838
-
839
- class MiniMaxText01MLP(nn.Module):
840
- def __init__(self, config):
841
- super().__init__()
842
- self.config = config
843
- self.hidden_size = config.hidden_size
844
- self.intermediate_size = config.intermediate_size
845
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
846
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
847
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
848
- self.act_fn = ACT2FN[config.hidden_act]
849
-
850
- def forward(self, x):
851
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
852
- return down_proj
853
-
854
-
855
- class MiniMaxText01BlockSparseTop2MLP(nn.Module):
856
- def __init__(self, config: MiniMaxText01Config):
857
- super().__init__()
858
- self.ffn_dim = config.intermediate_size
859
- self.hidden_dim = config.hidden_size
860
-
861
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
862
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
863
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
864
-
865
- self.act_fn = ACT2FN[config.hidden_act]
866
-
867
- def forward(self, hidden_states):
868
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
869
- current_hidden_states = self.w2(current_hidden_states)
870
- return current_hidden_states
871
-
872
-
873
- class MiniMaxText01BLockSparseTop2MLP(MiniMaxText01BlockSparseTop2MLP):
874
- def __init__(self, *args, **kwargs):
875
- logger.warning_once(
876
- "MiniMaxText01BLockSparseTop2MLP is deprecated by MiniMaxText01BlockSparseTop2MLP and will be removed in v4.40."
877
- )
878
- super().__init__(*args, **kwargs)
879
-
880
-
881
- class MiniMaxText01SparseMoeBlock(nn.Module):
882
- """
883
- This implementation is
884
- strictly equivalent to standard MoE with full capacity (no
885
- dropped tokens). It's faster since it formulates MoE operations
886
- in terms of block-sparse operations to accommodate imbalanced
887
- assignments of tokens to experts, whereas standard MoE either
888
- (1) drop tokens at the cost of reduced performance or (2) set
889
- capacity factor to number of experts and thus waste computation
890
- and memory on padding.
891
- """
892
-
893
- def __init__(self, config):
894
- super().__init__()
895
- self.hidden_dim = config.hidden_size
896
- self.ffn_dim = config.intermediate_size
897
- self.num_experts = config.num_local_experts
898
- self.top_k = config.num_experts_per_tok
899
-
900
- # gating
901
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
902
-
903
- self.experts = nn.ModuleList([MiniMaxText01BlockSparseTop2MLP(config) for _ in range(self.num_experts)])
904
-
905
- # Jitter parameters
906
- self.jitter_noise = config.router_jitter_noise
907
-
908
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
909
- """ """
910
- batch_size, sequence_length, hidden_dim = hidden_states.shape
911
- if self.training and self.jitter_noise > 0:
912
- hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
913
- hidden_states = hidden_states.view(-1, hidden_dim)
914
- # router_logits: (batch * sequence_length, n_experts)
915
- router_logits = self.gate(hidden_states)
916
-
917
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
918
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
919
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
920
- # we cast back to the input dtype
921
- routing_weights = routing_weights.to(hidden_states.dtype)
922
-
923
- final_hidden_states = torch.zeros(
924
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
925
- )
926
-
927
- # One hot encode the selected experts to create an expert mask
928
- # this will be used to easily index which expert is going to be solicited
929
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
930
-
931
- # Loop over all available experts in the model and perform the computation on each expert
932
- for expert_idx in range(self.num_experts):
933
- expert_layer = self.experts[expert_idx]
934
- idx, top_x = torch.where(expert_mask[expert_idx])
935
-
936
- # Index the correct hidden states and compute the expert hidden state for
937
- # the current expert. We need to make sure to multiply the output hidden
938
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
939
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
940
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
941
-
942
- # However `index_add_` only support torch tensors for indexing so we'll use
943
- # the `top_x` tensor here.
944
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
945
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
946
- return final_hidden_states, router_logits
947
-
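The routing step above, isolated on toy tensors (the token count and expert count here are illustrative only):

```python
import torch
import torch.nn.functional as F

router_logits = torch.randn(5, 8)                                     # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
topk_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize over the 2 chosen experts
# selected_experts holds the two expert indices per token; each token's output is then the
# weighted sum of those experts' MLP outputs, accumulated with index_add_ as in the forward above
```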
948
-
949
- class MiniMaxText01DecoderLayer(nn.Module):
950
- def __init__(self, config: MiniMaxText01Config, layer_idx: int):
951
- super().__init__()
952
- self.config = config
953
- self.hidden_size = config.hidden_size
954
-
955
- self.self_attn = self.build_attn(config, layer_idx)
956
-
957
- self.layer_idx = layer_idx
958
-
959
- self.block_sparse_moe = MiniMaxText01SparseMoeBlock(config)
960
- self.input_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
961
- self.post_attention_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
962
-
963
- self.postnorm = getattr(config, 'postnorm', False)
964
- self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
965
- if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_alpha', 1)
966
- self.layernorm_attention_beta = getattr(config, 'layernorm_linear_attention_beta', 1) \
967
- if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_beta', 1)
968
- self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
969
- self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
970
-
971
- shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
972
- self.shared_moe = False
973
- if shared_intermediate > 0:
974
- self.shared_moe = True
975
- self.shared_mlp = MiniMaxText01MLP(config)
976
- self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)
977
-
978
- def build_attn(self, config, layer_idx):
979
- if config.attention_type == 0:
980
- Attention_module = MiniMaxText01LightningAttention
981
- else:
982
- Attention_module = MiniMaxText01FlashAttention2
983
-
984
- return Attention_module(
985
- config,
986
- layer_idx
987
- )
988
-
989
- def forward(
990
- self,
991
- hidden_states: torch.Tensor,
992
- attention_mask: Optional[torch.Tensor] = None,
993
- position_ids: Optional[torch.LongTensor] = None,
994
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
995
- output_attentions: Optional[bool] = False,
996
- output_router_logits: Optional[bool] = False,
997
- use_cache: Optional[bool] = False,
998
- slope_rate: Optional[float] = None,
999
- **kwargs,
1000
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1001
- if "padding_mask" in kwargs:
1002
- warnings.warn(
1003
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
1004
- )
1005
- """
1006
- Args:
1007
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1008
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1009
- `(batch, sequence_length)` where padding elements are indicated by 0.
1010
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1011
- output_attentions (`bool`, *optional*):
1012
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1013
- returned tensors for more detail.
1014
- output_router_logits (`bool`, *optional*):
1015
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1016
- should not be returned during inference.
1017
- use_cache (`bool`, *optional*):
1018
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1019
- (see `past_key_values`).
1020
- """
1021
-
1022
- residual = hidden_states
1023
-
1024
- hidden_states = self.input_layernorm(hidden_states)
1025
- if self.postnorm:
1026
- residual = hidden_states
1027
-
1028
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
1029
- hidden_states=hidden_states,
1030
- position_ids=position_ids,
1031
- attn_mask=attention_mask,
1032
- past_key_value=past_key_value,
1033
- output_attentions=output_attentions,
1034
- use_cache=use_cache,
1035
- slope_rate=slope_rate,
1036
- )
1037
-
1038
- hidden_states = residual * self.layernorm_attention_alpha \
1039
- + hidden_states * self.layernorm_attention_beta
1040
-
1041
- # Fully Connected
1042
- residual = hidden_states
1043
- hidden_states = self.post_attention_layernorm(hidden_states)
1044
- if self.postnorm:
1045
- residual = hidden_states
1046
-
1047
- moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
1048
- if self.shared_moe:
1049
- output_mlp = self.shared_mlp(hidden_states)
1050
- weight_fp32 = self.coefficient.weight.float()
1051
- coef = hidden_states.to(torch.float32) @ weight_fp32.T
1052
- coef = torch.nn.functional.sigmoid(coef).to(hidden_states.dtype)
1053
- hidden_states = moe_hidden_states * (1 - coef) + output_mlp * coef
1054
- else:
1055
- hidden_states = moe_hidden_states
1056
-
1057
- hidden_states = residual * self.layernorm_mlp_alpha \
1058
- + hidden_states * self.layernorm_mlp_beta
1059
-
1060
- outputs = (hidden_states,)
1061
-
1062
- if output_attentions:
1063
- outputs += (self_attn_weights,)
1064
-
1065
- if use_cache:
1066
- outputs += (present_key_value,)
1067
-
1068
- if output_router_logits:
1069
- outputs += (router_logits,)
1070
-
1071
- return outputs
1072
-
1073
-
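The decoder layer above mixes residuals with learned `alpha`/`beta` scales and, when a shared expert is enabled, blends its output into the routed MoE output through a per-token sigmoid gate computed in float32. A minimal, self-contained sketch of just that gating step (toy tensor sizes; only the formula follows the `shared_moe` branch above):

```python
import torch

def gate_shared_expert(moe_out: torch.Tensor,
                       shared_out: torch.Tensor,
                       layer_input: torch.Tensor,
                       coefficient: torch.nn.Linear) -> torch.Tensor:
    """Per-token sigmoid gate between routed-MoE and shared-expert outputs."""
    weight_fp32 = coefficient.weight.float()                 # (1, hidden_size)
    coef = layer_input.to(torch.float32) @ weight_fp32.T     # (batch, seq, 1)
    coef = torch.sigmoid(coef).to(layer_input.dtype)
    return moe_out * (1 - coef) + shared_out * coef

# Toy usage with hypothetical sizes
hidden_size = 8
x = torch.randn(2, 4, hidden_size)
coefficient = torch.nn.Linear(hidden_size, 1, bias=False)
out = gate_shared_expert(torch.randn(2, 4, hidden_size),
                         torch.randn(2, 4, hidden_size),
                         x, coefficient)
print(out.shape)  # torch.Size([2, 4, 8])
```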
1074
- MIXTRAL_START_DOCSTRING = r"""
1075
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1076
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1077
- etc.)
1078
-
1079
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1080
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1081
- and behavior.
1082
-
1083
- Parameters:
1084
- config ([`MiniMaxText01Config`]):
1085
- Model configuration class with all the parameters of the model. Initializing with a config file does not
1086
- load the weights associated with the model, only the configuration. Check out the
1087
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1088
- """
1089
-
1090
-
1091
- @add_start_docstrings(
1092
- "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
1093
- MIXTRAL_START_DOCSTRING,
1094
- )
1095
- # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxText01
1096
- class MiniMaxText01PreTrainedModel(PreTrainedModel):
1097
- config_class = MiniMaxText01Config
1098
- base_model_prefix = "model"
1099
- supports_gradient_checkpointing = True
1100
- _no_split_modules = ["MiniMaxText01DecoderLayer"]
1101
- _skip_keys_device_placement = "past_key_values"
1102
- _supports_flash_attn_2 = True
1103
- _supports_sdpa = True
1104
-
1105
- def _init_weights(self, module):
1106
- std = self.config.initializer_range
1107
- if isinstance(module, nn.Linear):
1108
- module.weight.data.normal_(mean=0.0, std=std)
1109
- if module.bias is not None:
1110
- module.bias.data.zero_()
1111
- elif isinstance(module, nn.Embedding):
1112
- module.weight.data.normal_(mean=0.0, std=std)
1113
- if module.padding_idx is not None:
1114
- module.weight.data[module.padding_idx].zero_()
1115
-
1116
-
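`_init_weights` above initializes Linear and Embedding weights from a normal distribution with the configured `initializer_range`, zeroing biases and the padding embedding row. A standalone sketch, with 0.02 used as a placeholder std rather than the model's actual setting:

```python
import torch.nn as nn

def init_weights(module, std=0.02):
    """Normal(0, std) init for Linear/Embedding, zero bias, zero padding row."""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

emb = nn.Embedding(10, 4, padding_idx=0)
init_weights(emb)
print(emb.weight.data[0])   # tensor([0., 0., 0., 0.])
```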
1117
- MIXTRAL_INPUTS_DOCSTRING = r"""
1118
- Args:
1119
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1120
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1121
- it.
1122
-
1123
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1124
- [`PreTrainedTokenizer.__call__`] for details.
1125
-
1126
- [What are input IDs?](../glossary#input-ids)
1127
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1128
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1129
-
1130
- - 1 for tokens that are **not masked**,
1131
- - 0 for tokens that are **masked**.
1132
-
1133
- [What are attention masks?](../glossary#attention-mask)
1134
-
1135
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1136
- [`PreTrainedTokenizer.__call__`] for details.
1137
-
1138
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1139
- `past_key_values`).
1140
-
1141
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1142
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1143
- information on the default strategy.
1144
-
1145
- - 1 indicates the head is **not masked**,
1146
- - 0 indicates the head is **masked**.
1147
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1148
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1149
- config.n_positions - 1]`.
1150
-
1151
- [What are position IDs?](../glossary#position-ids)
1152
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1153
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1154
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1155
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1156
-
1157
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1158
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1159
-
1160
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1161
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1162
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1163
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1164
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1165
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1166
- model's internal embedding lookup matrix.
1167
- use_cache (`bool`, *optional*):
1168
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1169
- `past_key_values`).
1170
- output_attentions (`bool`, *optional*):
1171
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1172
- tensors for more detail.
1173
- output_hidden_states (`bool`, *optional*):
1174
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1175
- more detail.
1176
- output_router_logits (`bool`, *optional*):
1177
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1178
- should not be returned during inference.
1179
- return_dict (`bool`, *optional*):
1180
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1181
- """
1182
-
1183
-
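To make the `attention_mask` convention in the docstring above concrete: 1 marks a real token, 0 marks padding, and the flash-attention path checked in the forward pass below expects left padding for batched generation. A small illustration with a hypothetical pad id of 0:

```python
import torch

pad_id = 0                                     # hypothetical pad token id
sequences = [[11, 12, 13], [21, 22]]
max_len = max(len(s) for s in sequences)

# Left padding: pad tokens go in front of the shorter sequence.
input_ids = torch.tensor([[pad_id] * (max_len - len(s)) + s for s in sequences])
attention_mask = (input_ids != pad_id).long()  # 1 = real token, 0 = padding

print(input_ids)        # tensor([[11, 12, 13], [ 0, 21, 22]])
print(attention_mask)   # tensor([[1, 1, 1], [0, 1, 1]])
```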
1184
- @add_start_docstrings(
1185
- "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
1186
- MIXTRAL_START_DOCSTRING,
1187
- )
1188
- # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MINIMAXTEXT01, Mistral->MiniMaxText01
1189
- class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
1190
- """
1191
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxText01DecoderLayer`]
1192
-
1193
- Args:
1194
- config: MiniMaxText01Config
1195
- """
1196
-
1197
- def __init__(self, config: MiniMaxText01Config):
1198
- super().__init__(config)
1199
- self.padding_idx = config.pad_token_id
1200
- self.vocab_size = config.vocab_size
1201
-
1202
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1203
- self.layer_types = config.layer_types
1204
- config_copy = copy.deepcopy(config)
1205
-
1206
- self.layers = nn.ModuleList([])
1207
- for i in range(config.num_hidden_layers):
1208
- _config = copy.deepcopy(config)
1209
- if self.layer_types[i] == "linear_attention":
1210
- _config._attn_implementation = 'linear_attention'
1211
- _config.attention_type = 0
1212
- else:
1213
- _config._attn_implementation = config_copy._attn_implementation
1214
- _config.attention_type = 1
1215
- self.layers.append(MiniMaxText01DecoderLayer(_config, i))
1216
-
1217
- self._attn_implementation = config_copy._attn_implementation
1218
- self.norm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1219
-
1220
- self.gradient_checkpointing = False
1221
- self.slopes = self._build_slope_tensor(config.num_attention_heads)
1222
- # mask
1223
- self._linear_attn_mask = torch.empty(0)
1224
-
1225
- # Initialize weights and apply final processing
1226
- self.post_init()
1227
-
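Each decoder layer above is built from its own copy of the config, with `attention_type = 0` (lightning/linear attention) for layers tagged `"linear_attention"` and `attention_type = 1` (softmax attention) otherwise. The sketch below reproduces that loop with a dummy config and a hypothetical 8-layer pattern; the real `layer_types` come from the model configuration:

```python
import copy

class DummyConfig:
    """Stand-in for MiniMaxText01Config; only the fields used below."""
    _attn_implementation = "flash_attention_2"

def make_layer_configs(config, layer_types):
    """Per-layer config copies selecting linear vs. softmax attention."""
    layer_configs = []
    for layer_type in layer_types:
        cfg = copy.deepcopy(config)
        if layer_type == "linear_attention":
            cfg._attn_implementation = "linear_attention"
            cfg.attention_type = 0      # lightning (linear) attention
        else:
            cfg.attention_type = 1      # softmax (flash) attention
        layer_configs.append(cfg)
    return layer_configs

# Hypothetical 8-layer pattern: softmax attention on every 8th layer.
layer_types = ["full_attention" if (i + 1) % 8 == 0 else "linear_attention" for i in range(8)]
print([c.attention_type for c in make_layer_configs(DummyConfig(), layer_types)])
# [0, 0, 0, 0, 0, 0, 0, 1]
```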
1228
- def get_input_embeddings(self):
1229
- return self.embed_tokens
1230
-
1231
- def set_input_embeddings(self, value):
1232
- self.embed_tokens = value
1233
-
1234
- @staticmethod
1235
- def _build_slope_tensor(n_attention_heads: int):
1236
-
1237
- def get_slopes(n):
1238
-
1239
- def get_slopes_power_of_2(n):
1240
- start = 2 ** (-(2 ** -(math.log2(n) - 3)))
1241
- ratio = start
1242
- return [start * ratio ** i for i in range(n)]
1243
-
1244
- if math.log2(n).is_integer():
1245
- return get_slopes_power_of_2(
1246
- n) # In the paper, we only train models that have 2^a heads for some a. This function has
1247
- else: # some good properties that only occur when the input is a power of 2. To maintain that even
1248
- closest_power_of_2 = 2 ** math.floor(
1249
- math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
1250
- return (get_slopes_power_of_2(closest_power_of_2)
1251
- + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
1252
-
1253
- # h, 1, 1
1254
- slopes = torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32).reshape(n_attention_heads, 1, 1)
1255
-
1256
- return slopes
1257
-
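For reference, a standalone version of the ALiBi-style slope construction in `_build_slope_tensor`, showing the per-head decay rates the formula yields for eight heads (nothing model-specific is assumed):

```python
import math
import torch

def build_slopes(num_heads: int) -> torch.Tensor:
    """ALiBi-style slopes: one geometric decay rate per attention head."""
    def slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start ** i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        slopes = slopes_power_of_2(num_heads)
    else:
        # Fall back to the closest power of two and interleave extra slopes.
        closest = 2 ** math.floor(math.log2(num_heads))
        slopes = (slopes_power_of_2(closest)
                  + build_slopes(2 * closest).flatten().tolist()[0::2][: num_heads - closest])
    return torch.tensor(slopes, dtype=torch.float32).reshape(num_heads, 1, 1)

print(build_slopes(8).flatten())
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
```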
1258
- # Ignore copy
1259
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1260
- def forward(
1261
- self,
1262
- input_ids: torch.LongTensor = None,
1263
- attention_mask: Optional[torch.Tensor] = None,
1264
- position_ids: Optional[torch.LongTensor] = None,
1265
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1266
- inputs_embeds: Optional[torch.FloatTensor] = None,
1267
- use_cache: Optional[bool] = None,
1268
- output_attentions: Optional[bool] = None,
1269
- output_hidden_states: Optional[bool] = None,
1270
- output_router_logits: Optional[bool] = None,
1271
- return_dict: Optional[bool] = None,
1272
- ) -> Union[Tuple, MoeModelOutputWithPast]:
1273
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1274
- output_router_logits = (
1275
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1276
- )
1277
- output_hidden_states = (
1278
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1279
- )
1280
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1281
-
1282
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1283
-
1284
- # retrieve input_ids and inputs_embeds
1285
- if input_ids is not None and inputs_embeds is not None:
1286
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1287
- elif input_ids is not None:
1288
- batch_size, seq_length = input_ids.shape
1289
- default_device = input_ids.device
1290
- elif inputs_embeds is not None:
1291
- batch_size, seq_length, _ = inputs_embeds.shape
1292
- default_device = inputs_embeds.device
1293
- else:
1294
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1295
-
1296
- past_key_values_length = 0
1297
-
1298
- if self.gradient_checkpointing and self.training:
1299
- if use_cache:
1300
- logger.warning_once(
1301
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1302
- )
1303
- use_cache = False
1304
-
1305
- seq_length_with_past = seq_length
1306
- if past_key_values is not None:
1307
- for idx in range(len(past_key_values)):
1308
- if self.layer_types[idx] == "full_attention":
1309
- past_key_values_length = past_key_values[idx][0].shape[-3]
1310
- seq_length_with_past = seq_length_with_past + past_key_values_length
1311
- break
1312
-
1313
- if position_ids is None:
1314
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1315
- position_ids = torch.arange(
1316
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1317
- )
1318
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1319
- else:
1320
- position_ids = position_ids.view(-1, seq_length).long()
1321
-
1322
- if inputs_embeds is None:
1323
- inputs_embeds = self.embed_tokens(input_ids)
1324
-
1325
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1326
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1327
- if is_padding_right:
1328
- raise ValueError(
1329
- "You are attempting to perform batched generation with padding_side='right'"
1330
- " this may lead to unexpected behaviour for Flash Attention version of MiniMaxText01. Make sure to "
1331
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1332
- )
1333
- slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
1334
- hidden_states = inputs_embeds
1335
- # decoder layers
1336
- all_hidden_states = () if output_hidden_states else None
1337
- all_self_attns = () if output_attentions else None
1338
- all_router_logits = () if output_router_logits else None
1339
- next_decoder_cache = () if use_cache else None
1340
-
1341
- for idx, decoder_layer in enumerate(self.layers):
1342
- if output_hidden_states:
1343
- all_hidden_states += (hidden_states,)
1344
-
1345
- past_key_value = (past_key_values[idx] if past_key_values is not None else None)
1346
- attn_mask = attention_mask
1347
- slope_rate = slope_rates[idx]
1348
- slope_rate = slope_rate * (1 - idx / (len(self.layers) - 1) + 1e-5)
1349
- if self.gradient_checkpointing and self.training:
1350
- layer_outputs = self._gradient_checkpointing_func(
1351
- decoder_layer.__call__,
1352
- hidden_states,
1353
- attention_mask,
1354
- position_ids,
1355
- past_key_value,
1356
- output_attentions,
1357
- output_router_logits,
1358
- use_cache,
1359
- )
1360
- else:
1361
- layer_outputs = decoder_layer(
1362
- hidden_states,
1363
- attention_mask=attn_mask,
1364
- position_ids=position_ids,
1365
- past_key_value=past_key_value,
1366
- output_attentions=output_attentions,
1367
- output_router_logits=output_router_logits,
1368
- use_cache=use_cache,
1369
- slope_rate=slope_rate
1370
- )
1371
-
1372
- hidden_states = layer_outputs[0]
1373
-
1374
- if use_cache:
1375
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1376
-
1377
- if output_attentions:
1378
- all_self_attns += (layer_outputs[1],)
1379
-
1380
- if output_router_logits:
1381
- all_router_logits += (layer_outputs[-1],)
1382
-
1383
- hidden_states = self.norm(hidden_states)
1384
-
1385
- # add hidden states from the last decoder layer
1386
- if output_hidden_states:
1387
- all_hidden_states += (hidden_states,)
1388
- next_cache = next_decoder_cache if use_cache else None
1389
- if not return_dict:
1390
- return tuple(
1391
- v
1392
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1393
- if v is not None
1394
- )
1395
- return MoeModelOutputWithPast(
1396
- last_hidden_state=hidden_states,
1397
- past_key_values=next_cache,
1398
- hidden_states=all_hidden_states,
1399
- attentions=all_self_attns,
1400
- router_logits=all_router_logits,
1401
- )
1402
-
1403
-
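Within the forward pass above, each layer receives a scaled copy of the slope tensor, decaying roughly linearly with depth. The snippet below just evaluates that `slope_rate` scaling for a hypothetical four-layer stack:

```python
# Per-layer scaling applied to the ALiBi-style slopes in the forward pass above.
num_layers = 4                      # hypothetical depth, for illustration only
for idx in range(num_layers):
    scale = 1 - idx / (num_layers - 1) + 1e-5
    print(f"layer {idx}: slope scale = {scale:.5f}")
# layer 0: slope scale = 1.00001
# layer 1: slope scale = 0.66668
# layer 2: slope scale = 0.33334
# layer 3: slope scale = 0.00001
```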
1404
- class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
1405
- _tied_weights_keys = ["lm_head.weight"]
1406
-
1407
- def __init__(self, config):
1408
- super().__init__(config)
1409
- self.model = MiniMaxText01Model(config)
1410
- self.vocab_size = config.vocab_size
1411
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1412
- self.router_aux_loss_coef = config.router_aux_loss_coef
1413
- self.num_experts = config.num_local_experts
1414
- self.num_experts_per_tok = config.num_experts_per_tok
1415
- # Initialize weights and apply final processing
1416
- self.post_init()
1417
-
1418
- def get_input_embeddings(self):
1419
- return self.model.embed_tokens
1420
-
1421
- def set_input_embeddings(self, value):
1422
- self.model.embed_tokens = value
1423
-
1424
- def get_output_embeddings(self):
1425
- return self.lm_head
1426
-
1427
- def set_output_embeddings(self, new_embeddings):
1428
- self.lm_head = new_embeddings
1429
-
1430
- def set_decoder(self, decoder):
1431
- self.model = decoder
1432
-
1433
- def get_decoder(self):
1434
- return self.model
1435
-
1436
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1437
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1438
- # Ignore copy
1439
- def forward(
1440
- self,
1441
- input_ids: torch.LongTensor = None,
1442
- attention_mask: Optional[torch.Tensor] = None,
1443
- position_ids: Optional[torch.LongTensor] = None,
1444
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1445
- inputs_embeds: Optional[torch.FloatTensor] = None,
1446
- labels: Optional[torch.LongTensor] = None,
1447
- use_cache: Optional[bool] = None,
1448
- output_attentions: Optional[bool] = None,
1449
- output_hidden_states: Optional[bool] = None,
1450
- output_router_logits: Optional[bool] = None,
1451
- return_dict: Optional[bool] = None,
1452
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1453
- r"""
1454
- Args:
1455
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1456
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1457
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1458
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1459
-
1460
- Returns:
1461
-
1462
- Example:
1463
-
1464
- ```python
1465
- >>> from transformers import AutoTokenizer, MiniMaxText01ForCausalLM
1466
-
1467
- >>> model = MiniMaxText01ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
1468
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)
1469
-
1470
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1471
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1472
-
1473
- >>> # Generate
1474
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1475
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1476
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1477
- ```"""
1478
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1479
- output_router_logits = (
1480
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1481
- )
1482
-
1483
- output_hidden_states = (
1484
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1485
- )
1486
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1487
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1488
- outputs = self.model(
1489
- input_ids=input_ids,
1490
- attention_mask=attention_mask,
1491
- position_ids=position_ids,
1492
- past_key_values=past_key_values,
1493
- inputs_embeds=inputs_embeds,
1494
- use_cache=use_cache,
1495
- output_attentions=output_attentions,
1496
- output_hidden_states=output_hidden_states,
1497
- output_router_logits=output_router_logits,
1498
- return_dict=return_dict,
1499
- )
1500
-
1501
- hidden_states = outputs[0]
1502
- logits = self.lm_head(hidden_states)
1503
- logits = logits.float()
1504
-
1505
- loss = None
1506
- if labels is not None:
1507
- # Shift so that tokens < n predict n
1508
- shift_logits = logits[..., :-1, :].contiguous()
1509
- shift_labels = labels[..., 1:].contiguous()
1510
- # Flatten the tokens
1511
- loss_fct = CrossEntropyLoss()
1512
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1513
- shift_labels = shift_labels.view(-1)
1514
- # Enable model parallelism
1515
- shift_labels = shift_labels.to(shift_logits.device)
1516
- loss = loss_fct(shift_logits, shift_labels)
1517
-
1518
- aux_loss = None
1519
- if output_router_logits:
1520
- aux_loss = load_balancing_loss_func(
1521
- outputs.router_logits if return_dict else outputs[-1],
1522
- self.num_experts,
1523
- self.num_experts_per_tok,
1524
- attention_mask,
1525
- )
1526
- if labels is not None:
1527
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
1528
-
1529
- if not return_dict:
1530
- output = (logits,) + outputs[1:]
1531
- if output_router_logits:
1532
- output = (aux_loss,) + output
1533
- return (loss,) + output if loss is not None else output
1534
-
1535
- torch.cuda.empty_cache()
1536
- return MoeCausalLMOutputWithPast(
1537
- loss=loss,
1538
- aux_loss=aux_loss,
1539
- logits=logits,
1540
- past_key_values=outputs.past_key_values,
1541
- hidden_states=outputs.hidden_states,
1542
- attentions=outputs.attentions,
1543
- router_logits=outputs.router_logits,
1544
- )
1545
-
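The loss above follows the standard causal-LM recipe: logits at position t are scored against the token at position t+1, with the router auxiliary loss added on top when router logits are returned. A minimal sketch of the shifted cross-entropy step, with toy shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token prediction loss: position t predicts token t+1 (as above)."""
    shift_logits = logits[..., :-1, :].contiguous()   # drop the last position
    shift_labels = labels[..., 1:].contiguous()       # drop the first token
    loss_fct = CrossEntropyLoss()                     # ignores labels == -100 by default
    return loss_fct(shift_logits.view(-1, logits.size(-1)),
                    shift_labels.view(-1).to(shift_logits.device))

logits = torch.randn(2, 5, 32)            # (batch, seq_len, vocab_size) — toy sizes
labels = torch.randint(0, 32, (2, 5))
print(causal_lm_loss(logits, labels))
```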
1546
- def prepare_inputs_for_generation(
1547
- self,
1548
- input_ids,
1549
- past_key_values=None,
1550
- attention_mask=None,
1551
- inputs_embeds=None,
1552
- **kwargs,
1553
- ):
1554
- if past_key_values:
1555
- input_ids = input_ids[:, -1:]
1556
-
1557
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1558
- if inputs_embeds is not None and past_key_values is None:
1559
- model_inputs = {"inputs_embeds": inputs_embeds}
1560
- else:
1561
- model_inputs = {"input_ids": input_ids}
1562
-
1563
- model_inputs.update({
1564
- "past_key_values": past_key_values,
1565
- "use_cache": kwargs.get("use_cache"),
1566
- "attention_mask": attention_mask,
1567
- })
1568
- return model_inputs
1569
-
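`prepare_inputs_for_generation` above implements incremental decoding: once a cache exists, only the newest token id is fed back to the model. A tiny sketch of that behaviour (the truthiness check is simplified to an explicit `None` test):

```python
import torch

def next_step_inputs(input_ids: torch.Tensor, past_key_values=None):
    """With a cache present, only the newest token needs to be re-encoded."""
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "past_key_values": past_key_values}

ids = torch.tensor([[11, 42, 7, 99]])
print(next_step_inputs(ids)["input_ids"].shape)                       # torch.Size([1, 4]) — prefill
print(next_step_inputs(ids, past_key_values=object())["input_ids"])  # tensor([[99]]) — decode step
```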
1570
- @staticmethod
1571
- def _reorder_cache(past_key_values, beam_idx):
1572
- reordered_past = ()
1573
- for layer_past in past_key_values:
1574
- reordered_past += (
1575
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1576
- )
1577
- return reordered_past
1578
-
1579
-
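`_reorder_cache` above keeps the key/value cache aligned with the surviving beams during beam search by index-selecting along the batch dimension. A sketch for a single layer's cache with toy tensors:

```python
import torch

def reorder_layer_cache(layer_past: tuple, beam_idx: torch.Tensor) -> tuple:
    """Reorder one layer's cached states along the batch/beam dimension."""
    return tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)

# Toy cache for one layer: (key, value), batch-of-beams first.
key = torch.arange(3).view(3, 1, 1, 1).float()
value = key.clone()
beam_idx = torch.tensor([2, 0, 0])          # beams 2, 0, 0 survive this step
print(reorder_layer_cache((key, value), beam_idx)[0].flatten())  # tensor([2., 0., 0.])
```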
1580
- @add_start_docstrings(
1581
- """
1582
- The MiniMaxText01 Model transformer with a sequence classification head on top (linear layer).
1583
-
1584
- [`MiniMaxText01ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1585
- (e.g. GPT-2) do.
1586
-
1587
- Since it does classification on the last token, it requires to know the position of the last token. If a
1588
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1589
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1590
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1591
- each row of the batch).
1592
- """,
1593
- MIXTRAL_START_DOCSTRING,
1594
- )
1595
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxText01, LLAMA->MINIMAXTEXT01
1596
- class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
1597
- def __init__(self, config):
1598
- super().__init__(config)
1599
- self.num_labels = config.num_labels
1600
- self.model = MiniMaxText01Model(config)
1601
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1602
-
1603
- # Initialize weights and apply final processing
1604
- self.post_init()
1605
-
1606
- def get_input_embeddings(self):
1607
- return self.model.embed_tokens
1608
-
1609
- def set_input_embeddings(self, value):
1610
- self.model.embed_tokens = value
1611
-
1612
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1613
- def forward(
1614
- self,
1615
- input_ids: torch.LongTensor = None,
1616
- attention_mask: Optional[torch.Tensor] = None,
1617
- position_ids: Optional[torch.LongTensor] = None,
1618
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1619
- inputs_embeds: Optional[torch.FloatTensor] = None,
1620
- labels: Optional[torch.LongTensor] = None,
1621
- use_cache: Optional[bool] = None,
1622
- output_attentions: Optional[bool] = None,
1623
- output_hidden_states: Optional[bool] = None,
1624
- return_dict: Optional[bool] = None,
1625
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1626
- r"""
1627
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1628
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1629
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1630
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1631
- """
1632
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1633
-
1634
- transformer_outputs = self.model(
1635
- input_ids,
1636
- attention_mask=attention_mask,
1637
- position_ids=position_ids,
1638
- past_key_values=past_key_values,
1639
- inputs_embeds=inputs_embeds,
1640
- use_cache=use_cache,
1641
- output_attentions=output_attentions,
1642
- output_hidden_states=output_hidden_states,
1643
- return_dict=return_dict,
1644
- )
1645
- hidden_states = transformer_outputs[0]
1646
- logits = self.score(hidden_states)
1647
-
1648
- if input_ids is not None:
1649
- batch_size = input_ids.shape[0]
1650
- else:
1651
- batch_size = inputs_embeds.shape[0]
1652
-
1653
- if self.config.pad_token_id is None and batch_size != 1:
1654
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1655
- if self.config.pad_token_id is None:
1656
- sequence_lengths = -1
1657
- else:
1658
- if input_ids is not None:
1659
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1660
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1661
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1662
- sequence_lengths = sequence_lengths.to(logits.device)
1663
- else:
1664
- sequence_lengths = -1
1665
-
1666
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1667
-
1668
- loss = None
1669
- if labels is not None:
1670
- labels = labels.to(logits.device)
1671
- if self.config.problem_type is None:
1672
- if self.num_labels == 1:
1673
- self.config.problem_type = "regression"
1674
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1675
- self.config.problem_type = "single_label_classification"
1676
- else:
1677
- self.config.problem_type = "multi_label_classification"
1678
-
1679
- if self.config.problem_type == "regression":
1680
- loss_fct = MSELoss()
1681
- if self.num_labels == 1:
1682
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1683
- else:
1684
- loss = loss_fct(pooled_logits, labels)
1685
- elif self.config.problem_type == "single_label_classification":
1686
- loss_fct = CrossEntropyLoss()
1687
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1688
- elif self.config.problem_type == "multi_label_classification":
1689
- loss_fct = BCEWithLogitsLoss()
1690
- loss = loss_fct(pooled_logits, labels)
1691
- if not return_dict:
1692
- output = (pooled_logits,) + transformer_outputs[1:]
1693
- return ((loss,) + output) if loss is not None else output
1694
-
1695
- return SequenceClassifierOutputWithPast(
1696
- loss=loss,
1697
- logits=pooled_logits,
1698
- past_key_values=transformer_outputs.past_key_values,
1699
- hidden_states=transformer_outputs.hidden_states,
1700
- attentions=transformer_outputs.attentions,
1701
- )
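The classification head above pools the logits of the last non-padding token in each row (or simply the last position when no pad token is defined). A standalone sketch of that pooling, using 0 as a hypothetical pad id:

```python
import torch

def pool_last_token(logits: torch.Tensor, input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Pick the logits of the last non-padding token in each row."""
    # Index of the first pad token minus one; modulo keeps it valid when no pad is present.
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    batch_idx = torch.arange(input_ids.shape[0], device=logits.device)
    return logits[batch_idx, sequence_lengths.to(logits.device)]

input_ids = torch.tensor([[5, 6, 7, 0, 0],      # 0 is the (hypothetical) pad id
                          [8, 9, 1, 2, 3]])     # no padding: last position is used
logits = torch.randn(2, 5, 3)                   # (batch, seq, num_labels)
print(pool_last_token(logits, input_ids, pad_token_id=0).shape)  # torch.Size([2, 3])
```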