Goekdeniz-Guelmez committed on
Commit 82c60f3 · verified · 1 parent: 29536a9

Delete modeling_dream.py

Files changed (1)
  1. modeling_dream.py +0 -824
modeling_dream.py DELETED
@@ -1,824 +0,0 @@
- # coding=utf-8
- # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT and Qwen implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Dream model."""
-
- import math
- from typing import List, Optional, Tuple, Union
- import os
- import torch
- import torch.utils.checkpoint
- from torch import nn
-
- from transformers.activations import ACT2FN
- from transformers.cache_utils import Cache, DynamicCache
- from transformers.modeling_outputs import (
-     BaseModelOutput,
-     MaskedLMOutput,
- )
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import (
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     is_flash_attn_2_available,
-     is_flash_attn_greater_or_equal_2_10,
-     logging,
- )
- from transformers import PretrainedConfig
- from .configuration_dream import DreamConfig
- from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
-
- if is_flash_attn_2_available():
-     from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
-
- logger = logging.get_logger(__name__)
-
-
- _CHECKPOINT_FOR_DOC = "Dream-7B"
- _CONFIG_FOR_DOC = "DreamConfig"
-
-
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
- class DreamRMSNorm(nn.Module):
-     def __init__(self, hidden_size, eps=1e-6):
-         """
-         DreamRMSNorm is equivalent to T5LayerNorm
-         """
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(hidden_size))
-         self.variance_epsilon = eps
-
-     def forward(self, hidden_states):
-         input_dtype = hidden_states.dtype
-         hidden_states = hidden_states.to(torch.float32)
-         variance = hidden_states.pow(2).mean(-1, keepdim=True)
-         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-         return self.weight * hidden_states.to(input_dtype)
-
-     def extra_repr(self):
-         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
- class DreamRotaryEmbedding(nn.Module):
-     def __init__(
-         self,
-         dim=None,
-         max_position_embeddings=2048,
-         base=10000,
-         device=None,
-         scaling_factor=1.0,
-         rope_type="default",
-         config: Optional[DreamConfig] = None,
-     ):
-         super().__init__()
-         # TODO (joao): remove the `if` below, only used for BC
-         self.rope_kwargs = {}
-         if config is None:
-             logger.warning_once(
-                 "`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                 "`config` argument. All other arguments will be removed in v4.46"
-             )
-             self.rope_kwargs = {
-                 "rope_type": rope_type,
-                 "factor": scaling_factor,
-                 "dim": dim,
-                 "base": base,
-                 "max_position_embeddings": max_position_embeddings,
-             }
-             self.rope_type = rope_type
-             self.max_seq_len_cached = max_position_embeddings
-             self.original_max_seq_len = max_position_embeddings
-         else:
-             # BC: "rope_type" was originally "type"
-             if config.rope_scaling is not None:
-                 self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-             else:
-                 self.rope_type = "default"
-             self.max_seq_len_cached = config.max_position_embeddings
-             self.original_max_seq_len = config.max_position_embeddings
-
-         self.config = config
-         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.original_inv_freq = self.inv_freq
-
-     def reset_parameters(self):
-         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.original_inv_freq = self.inv_freq
-
-
-     def _dynamic_frequency_update(self, position_ids, device):
-         """
-         dynamic RoPE layers should recompute `inv_freq` in the following situations:
-         1 - growing beyond the cached sequence length (allow scaling)
-         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-         """
-         seq_len = torch.max(position_ids) + 1
-         if seq_len > self.max_seq_len_cached:  # growth
-             inv_freq, self.attention_scaling = self.rope_init_fn(
-                 self.config, device, seq_len=seq_len, **self.rope_kwargs
-             )
-             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-             self.max_seq_len_cached = seq_len
-
-         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-             self.max_seq_len_cached = self.original_max_seq_len
-
-     @torch.no_grad()
-     def forward(self, x, position_ids):
-         if "dynamic" in self.rope_type:
-             self._dynamic_frequency_update(position_ids, device=x.device)
-
-         # Core RoPE block
-         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-         position_ids_expanded = position_ids[:, None, :].float()
-         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-         device_type = x.device.type
-         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-         with torch.autocast(device_type=device_type, enabled=False):
-             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-             emb = torch.cat((freqs, freqs), dim=-1)
-             cos = emb.cos()
-             sin = emb.sin()
-
-         # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-         cos = cos * self.attention_scaling
-         sin = sin * self.attention_scaling
-
-         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
- # Copied from transformers.models.llama.modeling_llama.rotate_half
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return torch.cat((-x2, x1), dim=-1)
-
-
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-     """Applies Rotary Position Embedding to the query and key tensors.
-
-     Args:
-         q (`torch.Tensor`): The query tensor.
-         k (`torch.Tensor`): The key tensor.
-         cos (`torch.Tensor`): The cosine part of the rotary embedding.
-         sin (`torch.Tensor`): The sine part of the rotary embedding.
-         position_ids (`torch.Tensor`, *optional*):
-             Deprecated and unused.
-         unsqueeze_dim (`int`, *optional*, defaults to 1):
-             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-     Returns:
-         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-     """
-     cos = cos.unsqueeze(unsqueeze_dim)
-     sin = sin.unsqueeze(unsqueeze_dim)
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
-
-
- # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
- class DreamMLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.hidden_size = config.hidden_size
-         self.intermediate_size = config.intermediate_size
-         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-         self.act_fn = ACT2FN[config.hidden_act]
-
-     def forward(self, hidden_state):
-         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
-
-
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """
-     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-     """
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- class DreamAttention(nn.Module):
-     """
-     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
-     and "Generating Long Sequences with Sparse Transformers".
-     """
-
-     def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
-         super().__init__()
-         self.config = config
-         self.layer_idx = layer_idx
-         if layer_idx is None:
-             logger.warning_once(
-                 f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
-                 "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                 "when creating this class."
-             )
-
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = self.hidden_size // self.num_heads
-         self.num_key_value_heads = config.num_key_value_heads
-         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-         self.max_position_embeddings = config.max_position_embeddings
-         self.rope_theta = config.rope_theta
-         self.is_causal = False
-         self.attention_dropout = config.attention_dropout
-
-         if (self.head_dim * self.num_heads) != self.hidden_size:
-             raise ValueError(
-                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                 f" and `num_heads`: {self.num_heads})."
-             )
-         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
-         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
-         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
-         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
-         self.rotary_emb = DreamRotaryEmbedding(config=self.config)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Cache] = None,
-         output_attentions: bool = False,
-         use_cache: bool = False,
-         cache_position: Optional[torch.LongTensor] = None,
-         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         if position_embeddings is None:
-             logger.warning_once(
-                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                 "removed and `position_embeddings` will be mandatory."
-             )
-             cos, sin = self.rotary_emb(value_states, position_ids)
-         else:
-             cos, sin = position_embeddings
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-         if past_key_value is not None:
-             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-         # repeat k/v heads if n_kv_heads < n_heads
-         key_states = repeat_kv(key_states, self.num_key_value_groups)
-         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-         if attention_mask is not None:  # no matter the length, we just slice it
-             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-             attn_weights = attn_weights + causal_mask
-
-         # upcast attention to fp32
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-             raise ValueError(
-                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                 f" {attn_output.size()}"
-             )
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-         attn_output = self.o_proj(attn_output)
-
-         if not output_attentions:
-             attn_weights = None
-
-         return attn_output, attn_weights, past_key_value
-
-
- class DreamSdpaAttention(DreamAttention):
-     """
-     Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-     `DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
-     SDPA API.
-     """
-
-     # Adapted from DreamAttention.forward
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Cache] = None,
-         output_attentions: bool = False,
-         use_cache: bool = False,
-         cache_position: Optional[torch.LongTensor] = None,
-         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         if output_attentions:
-             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-             logger.warning_once(
-                 "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-             )
-             return super().forward(
-                 hidden_states=hidden_states,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 past_key_value=past_key_value,
-                 output_attentions=output_attentions,
-                 use_cache=use_cache,
-             )
-
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         if position_embeddings is None:
-             logger.warning_once(
-                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                 "removed and `position_embeddings` will be mandatory."
-             )
-             cos, sin = self.rotary_emb(value_states, position_ids)
-         else:
-             cos, sin = position_embeddings
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-         if past_key_value is not None:
-             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-         key_states = repeat_kv(key_states, self.num_key_value_groups)
-         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-         # causal_mask = attention_mask
-         # if attention_mask is not None:  # no matter the length, we just slice it
-         #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-
-         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-         if query_states.device.type == "cuda" and attention_mask is not None:
-             query_states = query_states.contiguous()
-             key_states = key_states.contiguous()
-             value_states = value_states.contiguous()
-
-         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-         # is_causal = True if causal_mask is None and q_len > 1 else False
-
-         attn_output = torch.nn.functional.scaled_dot_product_attention(
-             query_states,
-             key_states,
-             value_states,
-             attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
-             dropout_p=self.attention_dropout if self.training else 0.0,
-             is_causal=False,  # hard coded
-         )
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
-
-         attn_output = self.o_proj(attn_output)
-
-         return attn_output, None, past_key_value
-
-
- class DreamDecoderLayer(nn.Module):
-     def __init__(self, config: DreamConfig, layer_idx: int):
-         super().__init__()
-         self.hidden_size = config.hidden_size
-
-         if config.sliding_window and config._attn_implementation != "flash_attention_2":
-             logger.warning_once(
-                 f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
-                 "unexpected results may be encountered."
-             )
-
-         # self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-         self.self_attn = DreamSdpaAttention(config, layer_idx)
-
-         self.mlp = DreamMLP(config)
-         self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-         self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-         cache_position: Optional[torch.LongTensor] = None,
-         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-         """
-         Args:
-             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                 `(batch, sequence_length)` where padding elements are indicated by 0.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
-             use_cache (`bool`, *optional*):
-                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                 (see `past_key_values`).
-             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-                 Indices depicting the position of the input sequence tokens in the sequence.
-             position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
-                 Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
-                 with `head_dim` being the embedding dimension of each attention head.
-             kwargs (`dict`, *optional*):
-                 Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                 into the model
-         """
-
-         residual = hidden_states
-
-         hidden_states = self.input_layernorm(hidden_states)
-
-         # Self Attention
-         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-             hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
-             cache_position=cache_position,
-             position_embeddings=position_embeddings,
-         )
-         hidden_states = residual + hidden_states
-
-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights,)
-
-         if use_cache:
-             outputs += (present_key_value,)
-
-         return outputs
-
- class DreamPreTrainedModel(PreTrainedModel):
-     config_class = DreamConfig
-     base_model_prefix = "model"
-     supports_gradient_checkpointing = True
-     _no_split_modules = ["DreamDecoderLayer"]
-     _skip_keys_device_placement = "past_key_values"
-     _supports_flash_attn_2 = True
-     _supports_sdpa = True
-     _supports_cache_class = True
-     _supports_quantized_cache = True
-     _supports_static_cache = True
-
-     def _init_weights(self, module):
-         std = self.config.initializer_range
-         if isinstance(module, nn.Linear):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-
-     @classmethod
-     def from_pretrained(
-         cls,
-         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
-         *model_args,
-         config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
-         cache_dir: Optional[Union[str, os.PathLike]] = None,
-         ignore_mismatched_sizes: bool = False,
-         force_download: bool = False,
-         local_files_only: bool = False,
-         token: Optional[Union[str, bool]] = None,
-         revision: str = "main",
-         use_safetensors: Optional[bool] = None,
-         weights_only: bool = True,
-         **kwargs,
-     ):
-         _model = super().from_pretrained(
-             pretrained_model_name_or_path,
-             *model_args,
-             config=config,
-             cache_dir=cache_dir,
-             ignore_mismatched_sizes=ignore_mismatched_sizes,
-             force_download=force_download,
-             local_files_only=local_files_only,
-             token=token,
-             revision=revision,
-             use_safetensors=use_safetensors,
-             weights_only=weights_only,
-             **kwargs,
-         )
-         # NOTE(Lin): we need to override the generation config
-         # because the generation config loaded in `from_pretrained`
-         # does not include all the attributes of DreamGenerationConfig
-         resume_download = kwargs.get("resume_download", None)
-         proxies = kwargs.get("proxies", None)
-         subfolder = kwargs.get("subfolder", "")
-         from_auto_class = kwargs.get("_from_auto", False)
-         from_pipeline = kwargs.get("_from_pipeline", None)
-         _model.generation_config = DreamGenerationConfig.from_pretrained(
-             pretrained_model_name_or_path,
-             cache_dir=cache_dir,
-             force_download=force_download,
-             resume_download=resume_download,
-             proxies=proxies,
-             local_files_only=local_files_only,
-             token=token,
-             revision=revision,
-             subfolder=subfolder,
-             _from_auto=from_auto_class,
-             _from_pipeline=from_pipeline,
-         )
-         return _model
-
- class DreamBaseModel(DreamPreTrainedModel):
-     """
-     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
-
-     Args:
-         config: DreamConfig
-     """
-
-     def __init__(self, config: DreamConfig):
-         super().__init__(config)
-         self.padding_idx = config.pad_token_id
-         self.vocab_size = config.vocab_size
-
-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-         self.layers = nn.ModuleList(
-             [DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
-         )
-         self._attn_implementation = config._attn_implementation
-         self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-         self.rotary_emb = DreamRotaryEmbedding(config=config)
-
-         self.gradient_checkpointing = False
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.embed_tokens
-
-     def set_input_embeddings(self, value):
-         self.embed_tokens = value
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-     ) -> Union[Tuple, BaseModelOutput]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-         if self.gradient_checkpointing and self.training:
-             if use_cache:
-                 logger.warning_once(
-                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                 )
-                 use_cache = False
-
-         if inputs_embeds is None:
-             inputs_embeds = self.embed_tokens(input_ids)
-
-         if use_cache and past_key_values is None:
-             past_key_values = DynamicCache()
-
-         if cache_position is None:
-             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-             cache_position = torch.arange(
-                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-             )
-
-         if position_ids is None:
-             position_ids = cache_position.unsqueeze(0)
-
-         hidden_states = inputs_embeds
-
-         # create position embeddings to be shared across the decoder layers
-         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-         # decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-
-         for decoder_layer in self.layers:
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             if self.gradient_checkpointing and self.training:
-                 layer_outputs = self._gradient_checkpointing_func(
-                     decoder_layer.__call__,
-                     hidden_states,
-                     attention_mask,
-                     position_ids,
-                     past_key_values,
-                     output_attentions,
-                     use_cache,
-                     cache_position,
-                     position_embeddings,
-                 )
-             else:
-                 layer_outputs = decoder_layer(
-                     hidden_states,
-                     attention_mask=attention_mask,
-                     position_ids=position_ids,
-                     past_key_value=past_key_values,
-                     output_attentions=output_attentions,
-                     use_cache=use_cache,
-                     cache_position=cache_position,
-                     position_embeddings=position_embeddings,
-                 )
-
-             hidden_states = layer_outputs[0]
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-         hidden_states = self.norm(hidden_states)
-
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         if not return_dict:
-             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
-         return BaseModelOutput(
-             last_hidden_state=hidden_states,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-
- class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
-     _tied_weights_keys = ["lm_head.weight"]
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.model = DreamBaseModel(config)
-         self.vocab_size = config.vocab_size
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def reset_rope_parameters(self):
-         self.model.rotary_emb.reset_parameters()
-         for layer in self.model.layers:
-             layer.self_attn.rotary_emb.reset_parameters()
-
-     def get_input_embeddings(self):
-         return self.model.embed_tokens
-
-     def set_input_embeddings(self, value):
-         self.model.embed_tokens = value
-
-     def get_output_embeddings(self):
-         return self.lm_head
-
-     def set_output_embeddings(self, new_embeddings):
-         self.lm_head = new_embeddings
-
-     def set_decoder(self, decoder):
-         self.model = decoder
-
-     def get_decoder(self):
-         return self.model
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-         num_logits_to_keep: int = 0,
-         **loss_kwargs,
-     ) -> Union[Tuple, MaskedLMOutput]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-             cache_position=cache_position,
-         )
-
-         hidden_states = outputs[0]
-         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-
-         loss = None
-         if labels is not None:
-             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return MaskedLMOutput(
-             loss=loss,
-             logits=logits,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
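
For context, a custom modeling file like the one removed above is consumed through the Hub's remote-code mechanism rather than a built-in transformers architecture. The sketch below is a minimal, hypothetical usage example, not part of this commit: the repository id is a placeholder, and it assumes a checkpoint whose config `auto_map` still points at `modeling_dream.py` and `generation_utils.py`.

    # Minimal sketch, assuming a checkpoint that still ships the custom Dream code.
    import torch
    from transformers import AutoModel, AutoTokenizer

    repo_id = "Dream-org/Dream-v0-Instruct-7B"  # placeholder; substitute the actual repository

    # trust_remote_code=True lets transformers download and execute the custom
    # DreamModel / DreamGenerationMixin classes referenced in the config's auto_map.
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, torch_dtype=torch.bfloat16, trust_remote_code=True)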