petil777 commited on
Commit
351d0e7
·
1 Parent(s): d3ef0a4

Delete sr_tp_modeling.py

Browse files
Files changed (1) hide show
  1. sr_tp_modeling.py +0 -888
sr_tp_modeling.py DELETED
@@ -1,888 +0,0 @@
1
- """ PyTorch SRV1 model."""
2
- import sys
3
- import os
4
- from os import path
5
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
6
- print(sys.path)
7
- import math
8
- from typing import List, Optional, Tuple, Union
9
-
10
- import torch
11
- import torch.utils.checkpoint
12
- from torch import nn
13
- from torch.nn import CrossEntropyLoss
14
- from transformers.activations import ACT2FN
15
- from transformers import AutoTokenizer, AutoConfig
16
- from .configuration_srv1 import SRV1Config
17
-
18
- from transformers.modeling_outputs import (
19
- BaseModelOutputWithPast,
20
- CausalLMOutputWithPast,
21
- )
22
- from transformers.modeling_utils import PreTrainedModel
23
- from transformers.utils import (
24
- add_start_docstrings,
25
- add_start_docstrings_to_model_forward,
26
- logging,
27
- replace_return_docstrings,
28
- )
29
-
30
- from .layers import (
31
- TensorParallelColumnLinear,
32
- TensorParallelEmbedding,
33
- TensorParallelHead,
34
- TensorParallelRowLinear,
35
- load_layer_norm_no_bias,
36
- )
37
- from .dist import initialize_torch_distributed
38
- from .weights import Weights
39
-
40
- logger = logging.get_logger(__name__)
41
-
42
- _CONFIG_FOR_DOC = SRV1Config
43
-
44
-
45
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
46
- def _make_causal_mask(
47
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
48
- ):
49
- """
50
- Make causal mask used for bi-directional self-attention.
51
- """
52
- bsz, tgt_len = input_ids_shape
53
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
54
- mask_cond = torch.arange(mask.size(-1), device=device)
55
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
56
- mask = mask.to(dtype)
57
-
58
- if past_key_values_length > 0:
59
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
60
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
61
-
62
-
63
- # Copied from transformers.models.bart.modeling_bart._expand_mask
64
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
65
- """
66
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
67
- """
68
- bsz, src_len = mask.size()
69
- tgt_len = tgt_len if tgt_len is not None else src_len
70
-
71
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
72
-
73
- inverted_mask = 1.0 - expanded_mask
74
-
75
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
76
-
77
-
78
- class SRV1RMSNorm(nn.Module):
79
- def __init__(self, hidden_size, eps=1e-6):
80
- """
81
- SRV1RMSNorm is equivalent to T5LayerNorm
82
- """
83
- super().__init__()
84
- self.weight = nn.Parameter(torch.ones(hidden_size))
85
- self.variance_epsilon = eps
86
-
87
- def forward(self, hidden_states):
88
- input_dtype = hidden_states.dtype
89
- hidden_states = hidden_states.to(torch.float32)
90
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
91
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
92
- return self.weight * hidden_states.to(input_dtype)
93
-
94
-
95
- SRV1RMSNorm.load_no_bias = load_layer_norm_no_bias
96
-
97
-
98
- class SRV1RotaryEmbedding(torch.nn.Module):
99
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
100
- super().__init__()
101
-
102
- self.dim = dim
103
- self.max_position_embeddings = max_position_embeddings
104
- self.base = base
105
- self.inv_freq = self._create_inv_freq(dim=dim, base=base, device=device)
106
-
107
- # Build here to make `torch.jit.trace` work.
108
- self._set_cos_sin_cache(
109
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
110
- )
111
-
112
- def _set_cos_sin_cache(self, seq_len, device, dtype):
113
- self.max_seq_len_cached = seq_len
114
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
115
-
116
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
117
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
118
- emb = torch.cat((freqs, freqs), dim=-1)
119
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
120
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
121
-
122
- def forward(self, x, seq_len=None):
123
- # x: [bs, num_attention_heads, seq_len, head_size]
124
- if seq_len > self.max_seq_len_cached:
125
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
126
-
127
- return (
128
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
129
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
130
- )
131
-
132
- def _create_inv_freq(self, dim, base, device):
133
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
134
- return inv_freq
135
-
136
- class SRV1RotaryEmbedding(SRV1RotaryEmbedding):
137
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
138
- self.scaling_factor = scaling_factor
139
- super().__init__(dim, max_position_embeddings, base, device)
140
- def _set_cos_sin_cache(self, seq_len, device, dtype):
141
- self.max_seq_len_cached = seq_len
142
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
143
- t = t / self.scaling_factor
144
-
145
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
146
-
147
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
148
- emb = torch.cat((freqs, freqs), dim=-1)
149
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
150
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
151
-
152
- def rotate_half(x):
153
- """Rotates half the hidden dims of the input."""
154
- x1 = x[..., : x.shape[-1] // 2]
155
- x2 = x[..., x.shape[-1] // 2 :]
156
- return torch.cat((-x2, x1), dim=-1)
157
-
158
-
159
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
160
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
161
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
162
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
163
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
164
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
165
- q_embed = (q * cos) + (rotate_half(q) * sin)
166
- k_embed = (k * cos) + (rotate_half(k) * sin)
167
- return q_embed, k_embed
168
-
169
-
170
- class SRV1MLP(nn.Module):
171
- def __init__(self, prefix, config: SRV1Config, weigths):
172
- super().__init__()
173
- self.gate_proj = TensorParallelColumnLinear.load(
174
- config=config, prefix=f"{prefix}.gate_proj", weights=weigths, bias=False
175
- )
176
- self.up_proj = TensorParallelColumnLinear.load(
177
- config=config, prefix=f"{prefix}.up_proj", weights=weigths, bias=False
178
- )
179
- self.down_proj = TensorParallelRowLinear.load(
180
- config=config, prefix=f"{prefix}.down_proj", weights=weigths, bias=False
181
- )
182
- self.act_fn = ACT2FN[config.hidden_act]
183
-
184
- def forward(self, x):
185
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
186
- return down_proj
187
-
188
-
189
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
190
- """
191
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
192
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
193
- """
194
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
195
- if n_rep == 1:
196
- return hidden_states
197
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
198
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
199
-
200
-
201
- class SRV1Attention(nn.Module):
202
- """Multi-headed attention from 'Attention Is All You Need' paper"""
203
-
204
- def __init__(self, prefix, config: SRV1Config, weights):
205
- super().__init__()
206
- self.config = config
207
- self.hidden_size = config.hidden_size
208
- self.num_heads = config.num_attention_heads
209
- self.head_dim = self.hidden_size // self.num_heads
210
- self.num_key_value_heads = config.num_key_value_heads
211
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
212
- self.max_position_embeddings = config.max_position_embeddings
213
- self.rope_theta = getattr(config, "rope_theta", 10000)
214
-
215
- if (self.head_dim * self.num_heads) != self.hidden_size:
216
- raise ValueError(
217
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
218
- f" and `num_heads`: {self.num_heads})."
219
- )
220
-
221
- # for 1d tensor model parallel
222
- process_group = weights.process_group
223
- self.hidden_size = self.hidden_size // process_group.size()
224
- self.num_heads = self.num_heads // process_group.size()
225
- self.num_key_value_heads = self.num_key_value_heads // process_group.size()
226
-
227
- self.q_proj = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.q_proj", weights=weights, bias=False)
228
- self.k_proj = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.k_proj", weights=weights, bias=False)
229
- self.v_proj = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.v_proj", weights=weights, bias=False)
230
- self.o_proj = TensorParallelRowLinear.load(config, prefix=f"{prefix}.o_proj", weights=weights, bias=False)
231
- if self.config.rope_scaling is not None and self.config.rope_scaling['type'] == "linear":
232
- # Note, Not to use weights.device since rope should be calc on device cpu
233
- # have to model.to(cur_rank) !!!
234
- self.rotary_emb = SRV1RotaryEmbedding(
235
- self.head_dim, self.max_position_embeddings, base=self.rope_theta, scaling_factor=self.config.rope_scaling['factor']
236
- )
237
- else:
238
- self.rotary_emb = SRV1RotaryEmbedding(
239
- self.head_dim, self.max_position_embeddings, base=self.rope_theta
240
- )
241
-
242
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
243
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
244
-
245
- def forward(
246
- self,
247
- hidden_states: torch.Tensor,
248
- attention_mask: Optional[torch.Tensor] = None,
249
- position_ids: Optional[torch.LongTensor] = None,
250
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
251
- output_attentions: bool = False,
252
- use_cache: bool = False,
253
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
254
- bsz, q_len, _ = hidden_states.size()
255
-
256
- query_states = self.q_proj(hidden_states)
257
- key_states = self.k_proj(hidden_states)
258
- value_states = self.v_proj(hidden_states)
259
-
260
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
261
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
262
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
263
-
264
- kv_seq_len = key_states.shape[-2]
265
- if past_key_value is not None:
266
- kv_seq_len += past_key_value[0].shape[-2]
267
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
268
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
269
-
270
- if past_key_value is not None:
271
- # reuse k, v, self_attention
272
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
273
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
274
-
275
- past_key_value = (key_states, value_states) if use_cache else None
276
-
277
- # repeat k/v heads if n_kv_heads < n_heads
278
- key_states = repeat_kv(key_states, self.num_key_value_groups)
279
- value_states = repeat_kv(value_states, self.num_key_value_groups)
280
-
281
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
282
-
283
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
284
- raise ValueError(
285
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
286
- f" {attn_weights.size()}"
287
- )
288
-
289
- if attention_mask is not None:
290
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
291
- raise ValueError(
292
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
293
- )
294
- attn_weights = attn_weights + attention_mask
295
-
296
- # upcast attention to fp32
297
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
298
- attn_output = torch.matmul(attn_weights, value_states)
299
-
300
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
301
- raise ValueError(
302
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
303
- f" {attn_output.size()}"
304
- )
305
-
306
- attn_output = attn_output.transpose(1, 2).contiguous()
307
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
308
- attn_output = self.o_proj(attn_output)
309
-
310
- if not output_attentions:
311
- attn_weights = None
312
-
313
- return attn_output, attn_weights, past_key_value
314
-
315
-
316
- class SRV1DecoderLayer(nn.Module):
317
- def __init__(self, prefix, config: SRV1Config, weights):
318
- super().__init__()
319
- self.hidden_size = config.hidden_size
320
- self.self_attn = SRV1Attention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
321
- self.mlp = SRV1MLP(prefix=f"{prefix}.mlp", config=config, weigths=weights)
322
- self.input_layernorm = SRV1RMSNorm.load_no_bias(
323
- prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
324
- )
325
- self.post_attention_layernorm = SRV1RMSNorm.load_no_bias(
326
- prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps
327
- )
328
-
329
- def forward(
330
- self,
331
- hidden_states: torch.Tensor,
332
- attention_mask: Optional[torch.Tensor] = None,
333
- position_ids: Optional[torch.LongTensor] = None,
334
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
335
- output_attentions: Optional[bool] = False,
336
- use_cache: Optional[bool] = False,
337
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
338
- """
339
- Args:
340
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
341
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
342
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
343
- output_attentions (`bool`, *optional*):
344
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
345
- returned tensors for more detail.
346
- use_cache (`bool`, *optional*):
347
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
348
- (see `past_key_values`).
349
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
350
- """
351
-
352
- residual = hidden_states
353
-
354
- hidden_states = self.input_layernorm(hidden_states)
355
-
356
- # Self Attention
357
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
358
- hidden_states=hidden_states,
359
- attention_mask=attention_mask,
360
- position_ids=position_ids,
361
- past_key_value=past_key_value,
362
- output_attentions=output_attentions,
363
- use_cache=use_cache,
364
- )
365
- hidden_states = residual + hidden_states
366
-
367
- # Fully Connected
368
- residual = hidden_states
369
- hidden_states = self.post_attention_layernorm(hidden_states)
370
- hidden_states = self.mlp(hidden_states)
371
- hidden_states = residual + hidden_states
372
-
373
- outputs = (hidden_states,)
374
-
375
- if output_attentions:
376
- outputs += (self_attn_weights,)
377
-
378
- if use_cache:
379
- outputs += (present_key_value,)
380
-
381
- return outputs
382
-
383
-
384
- SRV1_START_DOCSTRING = r"""
385
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
386
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
387
- etc.)
388
-
389
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
390
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
391
- and behavior.
392
-
393
- Parameters:
394
- config ([`SRV1Config`]):
395
- Model configuration class with all the parameters of the model. Initializing with a config file does not
396
- load the weights associated with the model, only the configuration. Check out the
397
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
398
- """
399
-
400
-
401
- @add_start_docstrings(
402
- "The bare SRV1 Model outputting raw hidden-states without any specific head on top.",
403
- SRV1_START_DOCSTRING,
404
- )
405
- class SRV1PreTrainedModel(PreTrainedModel):
406
- config_class = SRV1Config
407
- base_model_prefix = "model"
408
- supports_gradient_checkpointing = True
409
- _no_split_modules = ["SRV1DecoderLayer"]
410
- _skip_keys_device_placement = "past_key_values"
411
-
412
- def _init_weights(self, module):
413
- std = self.config.initializer_range
414
- if isinstance(module, nn.Linear):
415
- module.weight.data.normal_(mean=0.0, std=std)
416
- if module.bias is not None:
417
- module.bias.data.zero_()
418
- elif isinstance(module, nn.Embedding):
419
- module.weight.data.normal_(mean=0.0, std=std)
420
- if module.padding_idx is not None:
421
- module.weight.data[module.padding_idx].zero_()
422
-
423
- def _set_gradient_checkpointing(self, module, value=False):
424
- if isinstance(module, SRV1Model):
425
- module.gradient_checkpointing = value
426
-
427
-
428
- SRV1_INPUTS_DOCSTRING = r"""
429
- Args:
430
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
431
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
432
- it.
433
-
434
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
435
- [`PreTrainedTokenizer.__call__`] for details.
436
-
437
- [What are input IDs?](../glossary#input-ids)
438
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
439
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
440
-
441
- - 1 for tokens that are **not masked**,
442
- - 0 for tokens that are **masked**.
443
-
444
- [What are attention masks?](../glossary#attention-mask)
445
-
446
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
447
- [`PreTrainedTokenizer.__call__`] for details.
448
-
449
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
450
- `past_key_values`).
451
-
452
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
453
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
454
- information on the default strategy.
455
-
456
- - 1 indicates the head is **not masked**,
457
- - 0 indicates the head is **masked**.
458
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
459
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
460
- config.n_positions - 1]`.
461
-
462
- [What are position IDs?](../glossary#position-ids)
463
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
464
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
465
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
466
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
467
-
468
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
469
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
470
-
471
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
472
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
473
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
474
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
475
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
476
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
477
- model's internal embedding lookup matrix.
478
- use_cache (`bool`, *optional*):
479
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
480
- `past_key_values`).
481
- output_attentions (`bool`, *optional*):
482
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
483
- tensors for more detail.
484
- output_hidden_states (`bool`, *optional*):
485
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
486
- more detail.
487
- return_dict (`bool`, *optional*):
488
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
489
- """
490
-
491
-
492
- @add_start_docstrings(
493
- "The bare SRV1 Model outputting raw hidden-states without any specific head on top.",
494
- SRV1_START_DOCSTRING,
495
- )
496
- class SRV1Model(SRV1PreTrainedModel):
497
- """
498
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SRV1DecoderLayer`]
499
-
500
- Args:
501
- config: SRV1Config
502
- """
503
-
504
- def __init__(self, config: SRV1Config, weights):
505
- super().__init__(config)
506
- self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
507
- self.layers = nn.ModuleList(
508
- [
509
- SRV1DecoderLayer(prefix=f"model.layers.{_}", config=config, weights=weights)
510
- for _ in range(config.num_hidden_layers)
511
- ]
512
- )
513
- # self.norm = SRV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
514
- self.norm = SRV1RMSNorm.load_no_bias(prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps)
515
- self.gradient_checkpointing = False
516
- # Initialize weights and apply final processing
517
- self.post_init()
518
-
519
- def get_input_embeddings(self):
520
- return self.embed_tokens
521
-
522
- def set_input_embeddings(self, value):
523
- self.embed_tokens = value
524
-
525
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
526
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
527
- # create causal mask
528
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
529
- combined_attention_mask = None
530
- if input_shape[-1] > 1:
531
- combined_attention_mask = _make_causal_mask(
532
- input_shape,
533
- inputs_embeds.dtype,
534
- device=inputs_embeds.device,
535
- past_key_values_length=past_key_values_length,
536
- )
537
-
538
- if attention_mask is not None:
539
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
540
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
541
- inputs_embeds.device
542
- )
543
- combined_attention_mask = (
544
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
545
- )
546
-
547
- return combined_attention_mask
548
-
549
- @add_start_docstrings_to_model_forward(SRV1_INPUTS_DOCSTRING)
550
- def forward(
551
- self,
552
- input_ids: torch.LongTensor = None,
553
- attention_mask: Optional[torch.Tensor] = None,
554
- position_ids: Optional[torch.LongTensor] = None,
555
- past_key_values: Optional[List[torch.FloatTensor]] = None,
556
- inputs_embeds: Optional[torch.FloatTensor] = None,
557
- use_cache: Optional[bool] = None,
558
- output_attentions: Optional[bool] = None,
559
- output_hidden_states: Optional[bool] = None,
560
- return_dict: Optional[bool] = None,
561
- ) -> Union[Tuple, BaseModelOutputWithPast]:
562
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
563
- output_hidden_states = (
564
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
565
- )
566
- use_cache = use_cache if use_cache is not None else self.config.use_cache
567
-
568
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
569
-
570
- # retrieve input_ids and inputs_embeds
571
- if input_ids is not None and inputs_embeds is not None:
572
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
573
- elif input_ids is not None:
574
- batch_size, seq_length = input_ids.shape
575
- elif inputs_embeds is not None:
576
- batch_size, seq_length, _ = inputs_embeds.shape
577
- else:
578
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
579
-
580
- seq_length_with_past = seq_length
581
- past_key_values_length = 0
582
-
583
- if past_key_values is not None:
584
- past_key_values_length = past_key_values[0][0].shape[2]
585
- seq_length_with_past = seq_length_with_past + past_key_values_length
586
-
587
- if position_ids is None:
588
- device = input_ids.device if input_ids is not None else inputs_embeds.device
589
- position_ids = torch.arange(
590
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
591
- )
592
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
593
- else:
594
- position_ids = position_ids.view(-1, seq_length).long()
595
-
596
- if inputs_embeds is None:
597
- inputs_embeds = self.embed_tokens(input_ids)
598
- # embed positions
599
- if attention_mask is None:
600
- attention_mask = torch.ones(
601
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
602
- )
603
- attention_mask = self._prepare_decoder_attention_mask(
604
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
605
- )
606
-
607
- hidden_states = inputs_embeds
608
-
609
- if self.gradient_checkpointing and self.training:
610
- if use_cache:
611
- logger.warning_once(
612
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
613
- )
614
- use_cache = False
615
-
616
- # decoder layers
617
- all_hidden_states = () if output_hidden_states else None
618
- all_self_attns = () if output_attentions else None
619
- next_decoder_cache = () if use_cache else None
620
-
621
- for idx, decoder_layer in enumerate(self.layers):
622
- if output_hidden_states:
623
- all_hidden_states += (hidden_states,)
624
-
625
- past_key_value = past_key_values[idx] if past_key_values is not None else None
626
-
627
- if self.gradient_checkpointing and self.training:
628
-
629
- def create_custom_forward(module):
630
- def custom_forward(*inputs):
631
- # None for past_key_value
632
- return module(*inputs, past_key_value, output_attentions)
633
-
634
- return custom_forward
635
-
636
- layer_outputs = torch.utils.checkpoint.checkpoint(
637
- create_custom_forward(decoder_layer),
638
- hidden_states,
639
- attention_mask,
640
- position_ids,
641
- )
642
- else:
643
- layer_outputs = decoder_layer(
644
- hidden_states,
645
- attention_mask=attention_mask,
646
- position_ids=position_ids,
647
- past_key_value=past_key_value,
648
- output_attentions=output_attentions,
649
- use_cache=use_cache,
650
- )
651
-
652
- hidden_states = layer_outputs[0]
653
-
654
- if use_cache:
655
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
656
-
657
- if output_attentions:
658
- all_self_attns += (layer_outputs[1],)
659
-
660
- hidden_states = self.norm(hidden_states)
661
-
662
- # add hidden states from the last decoder layer
663
- if output_hidden_states:
664
- all_hidden_states += (hidden_states,)
665
-
666
- next_cache = next_decoder_cache if use_cache else None
667
- if not return_dict:
668
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
669
- return BaseModelOutputWithPast(
670
- last_hidden_state=hidden_states,
671
- past_key_values=next_cache,
672
- hidden_states=all_hidden_states,
673
- attentions=all_self_attns,
674
- )
675
-
676
-
677
- class SRV1ForCausalLM(SRV1PreTrainedModel):
678
- _tied_weights_keys = ["lm_head.weight"]
679
-
680
- def __init__(self, config, weights):
681
- super().__init__(config)
682
- self.model = SRV1Model(config, weights)
683
- self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
684
- # Initialize weights and apply final processing
685
- self.post_init()
686
-
687
- def get_input_embeddings(self):
688
- return self.model.embed_tokens
689
-
690
- def set_input_embeddings(self, value):
691
- self.model.embed_tokens = value
692
-
693
- def get_output_embeddings(self):
694
- return self.lm_head
695
-
696
- def set_output_embeddings(self, new_embeddings):
697
- self.lm_head = new_embeddings
698
-
699
- def set_decoder(self, decoder):
700
- self.model = decoder
701
-
702
- def get_decoder(self):
703
- return self.model
704
-
705
- @add_start_docstrings_to_model_forward(SRV1_INPUTS_DOCSTRING)
706
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
707
- def forward(
708
- self,
709
- input_ids: torch.LongTensor = None,
710
- attention_mask: Optional[torch.Tensor] = None,
711
- position_ids: Optional[torch.LongTensor] = None,
712
- past_key_values: Optional[List[torch.FloatTensor]] = None,
713
- inputs_embeds: Optional[torch.FloatTensor] = None,
714
- labels: Optional[torch.LongTensor] = None,
715
- use_cache: Optional[bool] = None,
716
- output_attentions: Optional[bool] = None,
717
- output_hidden_states: Optional[bool] = None,
718
- return_dict: Optional[bool] = None,
719
- ) -> Union[Tuple, CausalLMOutputWithPast]:
720
- r"""
721
- Args:
722
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
723
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
724
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
725
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
726
-
727
- Returns:
728
-
729
- Example:
730
-
731
- ```python
732
- >>> from transformers import AutoTokenizer, SRV1ForCausalLM
733
-
734
- >>> model = SRV1ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
735
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
736
-
737
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
738
- >>> inputs = tokenizer(prompt, return_tensors="pt")
739
-
740
- >>> # Generate
741
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
742
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
743
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
744
- ```"""
745
-
746
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
747
- output_hidden_states = (
748
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
749
- )
750
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
751
-
752
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
753
- outputs = self.model(
754
- input_ids=input_ids,
755
- attention_mask=attention_mask,
756
- position_ids=position_ids,
757
- past_key_values=past_key_values,
758
- inputs_embeds=inputs_embeds,
759
- use_cache=use_cache,
760
- output_attentions=output_attentions,
761
- output_hidden_states=output_hidden_states,
762
- return_dict=return_dict,
763
- )
764
-
765
- hidden_states = outputs[0]
766
- logits = self.lm_head(hidden_states)
767
- logits = logits.float()
768
-
769
- loss = None
770
- if labels is not None:
771
- # Shift so that tokens < n predict n
772
- shift_logits = logits[..., :-1, :].contiguous()
773
- shift_labels = labels[..., 1:].contiguous()
774
- # Flatten the tokens
775
- loss_fct = CrossEntropyLoss()
776
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
777
- shift_labels = shift_labels.view(-1)
778
- # Enable model parallelism
779
- shift_labels = shift_labels.to(shift_logits.device)
780
- loss = loss_fct(shift_logits, shift_labels)
781
-
782
- if not return_dict:
783
- output = (logits,) + outputs[1:]
784
- return (loss,) + output if loss is not None else output
785
-
786
- return CausalLMOutputWithPast(
787
- loss=loss,
788
- logits=logits,
789
- past_key_values=outputs.past_key_values,
790
- hidden_states=outputs.hidden_states,
791
- attentions=outputs.attentions,
792
- )
793
-
794
- def prepare_inputs_for_generation(
795
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
796
- ):
797
- if past_key_values:
798
- input_ids = input_ids[:, -1:]
799
-
800
- position_ids = kwargs.get("position_ids", None)
801
- if attention_mask is not None and position_ids is None:
802
- # create position_ids on the fly for batch generation
803
- position_ids = attention_mask.long().cumsum(-1) - 1
804
- position_ids.masked_fill_(attention_mask == 0, 1)
805
- if past_key_values:
806
- position_ids = position_ids[:, -1].unsqueeze(-1)
807
-
808
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
809
- if inputs_embeds is not None and past_key_values is None:
810
- model_inputs = {"inputs_embeds": inputs_embeds}
811
- else:
812
- model_inputs = {"input_ids": input_ids}
813
-
814
- model_inputs.update(
815
- {
816
- "position_ids": position_ids,
817
- "past_key_values": past_key_values,
818
- "use_cache": kwargs.get("use_cache"),
819
- "attention_mask": attention_mask,
820
- }
821
- )
822
- return model_inputs
823
-
824
- @staticmethod
825
- def _reorder_cache(past_key_values, beam_idx):
826
- reordered_past = ()
827
- for layer_past in past_key_values:
828
- reordered_past += (
829
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
830
- )
831
- return reordered_past
832
-
833
- class SRV1ForCausalLMParallel(SRV1ForCausalLM):
834
- def __init__(self, config, **kwargs):
835
- model_id = kwargs.get("local_path", None)
836
- if model_id is None:
837
- model_id = kwargs.get("pretrained_model_name_or_path", None)
838
- revision = kwargs.get("revision", None)
839
- trust_remote_code = kwargs.get("trust_remote_code", False)
840
- quantize = kwargs.get("quantize", None)
841
- dtype = kwargs.get("dtype", None)
842
- print("Start initializing...")
843
- self.process_group, rank, world_size = initialize_torch_distributed()
844
- print(f"RANK[{rank}]: Distributed Initialize Success")
845
- if torch.cuda.is_available():
846
- device = torch.device(f"cuda:{rank}")
847
- dtype = torch.float16 if dtype is None else dtype
848
- print(f"Use dtype {dtype}")
849
- else:
850
- raise NotImplementedError("Flash is only available on GPU")
851
-
852
- print(f"Will read model dir {model_id}")
853
- self.tokenizer = AutoTokenizer.from_pretrained(
854
- model_id,
855
- revision=revision,
856
- padding_side="left",
857
- truncation_side="left",
858
- trust_remote_code=trust_remote_code,
859
- )
860
- # config already defined in from_pretrained
861
- # config = SRV1Config.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
862
- config.quantize = quantize
863
- torch.distributed.barrier(group=self.process_group)
864
- import glob
865
- filenames = glob.glob(f"{model_id}/*.safetensors")
866
- print(f"Will read filename {filenames}")
867
- weights = Weights(filenames=filenames, device=device, dtype=dtype, process_group=self.process_group)
868
- print(f"RANK[{rank}]: Loaded Weights success. device:{device}")
869
-
870
- torch.distributed.barrier(group=self.process_group)
871
- super(SRV1ForCausalLMParallel, self).__init__(
872
- config=config,
873
- weights=weights
874
- )
875
- print(f"RANK[{rank}]: parallel load success")
876
- torch.distributed.barrier(group=self.process_group)
877
-
878
- @classmethod
879
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
880
- config_path = config if config is not None else pretrained_model_name_or_path
881
-
882
- config = cls.config_class.from_pretrained(
883
- config_path,
884
- **kwargs,
885
- )
886
- kwargs.update({"pretrained_model_name_or_path": pretrained_model_name_or_path})
887
- model = cls(config, *model_args, **kwargs)
888
- return model