QscQ committed on
Commit
55a7282
·
1 Parent(s): 4402802
Files changed (2)
  1. config.json +80 -80
  2. modeling_minimax_text_01.py +3 -3
config.json CHANGED
@@ -4,86 +4,86 @@
   ],
   "attention_dropout": 0.0,
   "layer_types": [
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    0,
-    1
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "linear_attention",
+    "full_attention"
   ],
   "auto_map": {
     "AutoConfig": "configuration_minimax_text_01.MiniMaxText01Config",
modeling_minimax_text_01.py CHANGED
@@ -1200,13 +1200,13 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.attn_type_list = config.attn_type_list
+        self.layer_types = config.layer_types
         config_copy = copy.deepcopy(config)
 
         self.layers = nn.ModuleList([])
         for i in range(config.num_hidden_layers):
             _config = copy.deepcopy(config)
-            if self.attn_type_list[i] == 0:
+            if self.layer_types[i] == "linear_attention":
                 _config._attn_implementation = 'linear_attention'
                 _config.attention_type = 0
             else:
@@ -1305,7 +1305,7 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
         seq_length_with_past = seq_length
         if past_key_values is not None:
             for idx in range(len(past_key_values)):
-                if self.attn_type_list[idx] == 1:
+                if self.layer_types[idx] == "full_attention":
                     past_key_values_length = past_key_values[idx][0].shape[-3]
                     seq_length_with_past = seq_length_with_past + past_key_values_length
                     break
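Taken together, the two hunks replace the old integer-coded attn_type_list (0 for linear attention, 1 for full attention) with the string-valued layer_types list read from config.json. A small, hypothetical migration helper (not part of the commit) illustrating that mapping:

# Illustrative helper only: converts an old-style integer attn_type_list
# (0 = linear attention, 1 = full attention) into the new string-valued
# layer_types expected by the updated modeling code.
def attn_type_list_to_layer_types(attn_type_list):
    mapping = {0: "linear_attention", 1: "full_attention"}
    return [mapping[t] for t in attn_type_list]

# Example: the old config's 0/1 pattern maps onto the new strings.
old = [0, 0, 0, 0, 0, 0, 0, 1] * 10
new = attn_type_list_to_layer_types(old)
assert new.count("full_attention") == 10
assert new[7] == "full_attention"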