manueldeprada (HF Staff) committed
Commit 0cb565f · verified · 1 Parent(s): 9e6802a

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
# Automatically created by ruff.
*
.ruff_cache/0.12.8/7010951691598163845 ADDED
Binary file (149 Bytes).
 
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
Signature: 8a477f597d28d172789f06886806bc55
README.md ADDED
@@ -0,0 +1,118 @@
---
library_name: transformers
tags:
- custom_generate
---

## Description

Constrained Beam Search extends standard beam search by allowing you to enforce lexical or phrasal constraints in the generated output. This is useful when you know certain words or phrases must appear (e.g., translation dictionaries, product names, slot values), or when multiple outputs are equally probable but only some are desirable for your use case.

Unlike ordinary beam search, constrained beam search steers generation so that the required subsequences appear somewhere in the final output while still balancing fluency.

---

## Why it's difficult

Beam search generates token by token and scores candidates locally. Forcing a phrase like "is fast" to appear somewhere requires the search to plan several steps ahead and decide when to insert the constrained tokens without breaking fluency. The problem becomes more complex with multiple constraints, optional alternatives, or ordering requirements.

Constrained beam search solves this by:
- Injecting constraint-progressing tokens among the regular high-probability candidates
- Grouping beams into banks according to how much of the constraints they have satisfied
- Selecting beams round-robin across banks to balance fluency and constraint satisfaction (sketched below)
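
The selection step can be pictured as follows. This is a toy sketch of the round-robin idea only, not the repository's implementation (which lives in `custom_generate/beam_search.py`); a beam's "bank" roughly counts how many constraint tokens it has already completed:

```python
# Illustrative sketch: round-robin beam selection across constraint "banks".
from collections import defaultdict


def select_round_robin(candidates, num_beams):
    """candidates: list of (bank, score, beam_id); a higher bank means more constraint progress."""
    by_bank = defaultdict(list)
    for bank, score, beam_id in candidates:
        by_bank[bank].append((score, beam_id))
    for bank in by_bank:
        by_bank[bank].sort(reverse=True)  # best-scoring candidates first within each bank
    selected = []
    max_rank = max(len(group) for group in by_bank.values())
    for rank in range(max_rank):
        # take the rank-th best candidate from each bank, most-progressed bank first
        for bank in sorted(by_bank, reverse=True):
            if rank < len(by_bank[bank]) and len(selected) < num_beams:
                selected.append(by_bank[bank][rank][1])
    return selected


print(select_round_robin([(2, -1.0, "a"), (2, -1.5, "b"), (1, -0.9, "c"), (0, -0.2, "d")], num_beams=3))
# -> ['a', 'c', 'd']: one candidate per bank, most constraint progress first,
#    even though 'd' has the best raw score
```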

---

## Base model

* [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B)

---

## Model compatibility

- Encoder-decoder and decoder-only transformer models

---

## Additional Arguments

- `constraints` (list[Constraint]): Advanced constraints, e.g., `PhrasalConstraint`, `DisjunctiveConstraint` (see Example 3 below)
- `force_words_ids` (list[list[int]] | list[list[list[int]]]): Simple way to specify words/phrases or disjunctive sets
- `num_beams` (int): Beam width
- Other standard beam arguments: `length_penalty`, `early_stopping`, `num_return_sequences`, `max_length`

Notes:
- Constrained decoding is incompatible with sampling: set `do_sample=False`
- Tokenize constraints without adding special tokens (illustrated below)
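
A quick illustration of the second note (a sketch only; the exact ids depend on the tokenizer):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-base")

# Constraints should contain only the word's own token ids ...
print(tok("Sie", add_special_tokens=False).input_ids)

# ... not the special tokens the tokenizer normally appends (T5 adds </s> by default),
# which could never be satisfied naturally in the middle of a generated sentence.
print(tok("Sie").input_ids)
```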

---

## Example 1: Forcing a word (formal German translation)

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

force_words = ["Sie"]
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    force_words_ids=force_words_ids,
    num_beams=5,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    trust_remote_code=True,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Expected to contain the forced word: `Wie alt sind Sie?`

---

## Example 2: Disjunctive constraints (choose any of several forms)

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]

force_words_ids = [
    tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
    tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
]

starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    force_words_ids=force_words_ids,
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    trust_remote_code=True,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(tokenizer.decode(outputs[1], skip_special_tokens=True))
```

Outputs will include the mandatory word and at least one from the flexible set.
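
---

## Example 3: Using `Constraint` objects directly

The `constraints` argument accepts `Constraint` objects such as `PhrasalConstraint` and `DisjunctiveConstraint` (defined in this repository's `custom_generate/beam_constraints.py`). The snippet below is a minimal sketch of Example 1 rewritten with a `PhrasalConstraint`; it assumes `PhrasalConstraint` is still exported at the top level of your installed `transformers` version (otherwise build the constraint from this repository's `beam_constraints` module):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import PhrasalConstraint  # assumption: available in your transformers version

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids

# Build a phrasal constraint from the tokenized target word (no special tokens).
constraint = PhrasalConstraint(tokenizer("Sie", add_special_tokens=False).input_ids)

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    constraints=[constraint],
    num_beams=5,
    remove_invalid_values=True,
    trust_remote_code=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```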
config.json ADDED
@@ -0,0 +1,30 @@
{
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 40960,
  "max_window_layers": 28,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.56.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
custom_generate/beam_constraints.py ADDED
@@ -0,0 +1,524 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional
3
+
4
+
5
+ class Constraint(ABC):
6
+ r"""Abstract base class for all constraints that can be applied during generation.
7
+ It must define how the constraint can be satisfied.
8
+
9
+ All classes that inherit Constraint must follow the requirement that
10
+
11
+ ```py
12
+ completed = False
13
+ while not completed:
14
+ _, completed = constraint.update(constraint.advance())
15
+ ```
16
+
17
+ will always terminate (halt).
18
+ """
19
+
20
+ def __init__(self):
21
+ # test for the above condition
22
+ self.test()
23
+
24
+ def test(self):
25
+ """
26
+ Tests whether this constraint has been properly defined.
27
+ """
28
+ counter = 0
29
+ completed = False
30
+ while not completed:
31
+ if counter == 1:
32
+ self.reset()
33
+ advance = self.advance()
34
+ if not self.does_advance(advance):
35
+ raise Exception(
36
+ "Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
37
+ )
38
+
39
+ stepped, completed, reset = self.update(advance)
40
+ counter += 1
41
+
42
+ if counter > 10000:
43
+ raise Exception("update() does not fulfill the constraint.")
44
+
45
+ if self.remaining() != 0:
46
+ raise Exception("Custom Constraint is not defined correctly.")
47
+
48
+ @abstractmethod
49
+ def advance(self):
50
+ """
51
+ When called, returns the token(s) that would take this constraint one step closer to being fulfilled.
52
+
53
+ Return:
54
+ token_ids (Union[int, list[int], None]):
55
+ - A single token ID (int) that advances the constraint, or
56
+ - A list of token IDs that could advance the constraint
57
+ - None if the constraint is completed or cannot be advanced
58
+ """
59
+ raise NotImplementedError(
60
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
61
+ )
62
+
63
+ @abstractmethod
64
+ def does_advance(self, token_id: int):
65
+ """
66
+ Reads in a token and returns whether it creates progress.
67
+ """
68
+ raise NotImplementedError(
69
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
70
+ )
71
+
72
+ @abstractmethod
73
+ def update(self, token_id: int):
74
+ """
75
+ Reads in a token and returns booleans that indicate the progress made by it. This function will update the
76
+ state of this object, unlike `does_advance(self, token_id: int)`.
77
+
78
+ This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
79
+ been generated. This becomes important if token_id != desired token (refer to else statement in
80
+ PhrasalConstraint)
81
+
82
+ Args:
83
+ token_id(`int`):
84
+ The id of a newly generated token in the beam search.
85
+ Return:
86
+ stepped(`bool`):
87
+ Whether this constraint has become one step closer to being fulfilled.
88
+ completed(`bool`):
89
+ Whether this constraint has been completely fulfilled by this token being generated.
90
+ reset (`bool`):
91
+ Whether this constraint has reset its progress by this token being generated.
92
+ """
93
+ raise NotImplementedError(
94
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
95
+ )
96
+
97
+ @abstractmethod
98
+ def reset(self):
99
+ """
100
+ Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
101
+ a constraint is interrupted by an unwanted token.
102
+ """
103
+ raise NotImplementedError(
104
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
105
+ )
106
+
107
+ @abstractmethod
108
+ def remaining(self):
109
+ """
110
+ Returns the number of remaining steps of `advance()` in order to complete this constraint.
111
+ """
112
+ raise NotImplementedError(
113
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
114
+ )
115
+
116
+ @abstractmethod
117
+ def copy(self, stateful=False):
118
+ """
119
+ Creates a new instance of this constraint.
120
+
121
+ Args:
122
+ stateful(`bool`): Whether to also copy the constraint's current state, not just its definition.
123
+
124
+ Return:
125
+ constraint(`Constraint`): The same constraint as the one being called from.
126
+ """
127
+ raise NotImplementedError(
128
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
129
+ )
130
+
131
+
132
+ class PhrasalConstraint(Constraint):
133
+ r"""
134
+ [`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
135
+
136
+ Args:
137
+ token_ids (`list[int]`):
138
+ The ids of the tokens that must be generated by the output, in order.
139
+ """
140
+
141
+ def __init__(self, token_ids: list[int]):
142
+ super(Constraint, self).__init__()
143
+
144
+ if not isinstance(token_ids, list) or len(token_ids) == 0:
145
+ raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
146
+ if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
147
+ raise ValueError(f"`token_ids` has to be a list of positive integers, but is {token_ids}.")
148
+
149
+ self.token_ids = token_ids
150
+
151
+ self.seqlen = len(self.token_ids)
152
+ self.fulfilled_idx = -1 # the index of the currently fulfilled step
153
+ self.completed = False
154
+
155
+ def advance(self):
156
+ if self.completed:
157
+ return None
158
+ return self.token_ids[self.fulfilled_idx + 1]
159
+
160
+ def does_advance(self, token_id: int):
161
+ if not isinstance(token_id, int):
162
+ raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
163
+
164
+ if self.completed:
165
+ return False
166
+
167
+ return token_id == self.token_ids[self.fulfilled_idx + 1]
168
+
169
+ def update(self, token_id: int):
170
+ if not isinstance(token_id, int):
171
+ raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
172
+
173
+ stepped = False
174
+ completed = False
175
+ reset = False
176
+
177
+ if self.does_advance(token_id):
178
+ self.fulfilled_idx += 1
179
+ stepped = True
180
+ if self.fulfilled_idx == (self.seqlen - 1):
181
+ completed = True
182
+ self.completed = completed
183
+ else:
184
+ # failed to make progress.
185
+ reset = True
186
+ self.reset()
187
+ return stepped, completed, reset
188
+
189
+ def reset(self):
190
+ self.completed = False
191
+ self.fulfilled_idx = 0
192
+
193
+ def remaining(self):
194
+ return self.seqlen - (self.fulfilled_idx + 1)
195
+
196
+ def copy(self, stateful=False):
197
+ new_constraint = PhrasalConstraint(self.token_ids)
198
+
199
+ if stateful:
200
+ new_constraint.seqlen = self.seqlen
201
+ new_constraint.fulfilled_idx = self.fulfilled_idx
202
+ new_constraint.completed = self.completed
203
+
204
+ return new_constraint
205
+
206
+
207
+ class DisjunctiveTrie:
208
+ def __init__(self, nested_token_ids: list[list[int]], no_subsets=True):
209
+ r"""
210
+ A helper class that builds a trie with the words represented in `nested_token_ids`.
211
+ """
212
+ self.max_height = max([len(one) for one in nested_token_ids])
213
+
214
+ root = {}
215
+ for token_ids in nested_token_ids:
216
+ level = root
217
+ for tidx, token_id in enumerate(token_ids):
218
+ if token_id not in level:
219
+ level[token_id] = {}
220
+
221
+ level = level[token_id]
222
+
223
+ if no_subsets and self.has_subsets(root, nested_token_ids):
224
+ raise ValueError(
225
+ "Each list in `nested_token_ids` can't be a complete subset of another list, but is"
226
+ f" {nested_token_ids}."
227
+ )
228
+
229
+ self.trie = root
230
+
231
+ def next_tokens(self, current_seq):
232
+ """
233
+ The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
234
+ """
235
+ start = self.trie
236
+
237
+ for current_token in current_seq:
238
+ start = start[current_token]
239
+
240
+ next_tokens = list(start.keys())
241
+
242
+ return next_tokens
243
+
244
+ def reached_leaf(self, current_seq):
245
+ next_tokens = self.next_tokens(current_seq)
246
+
247
+ return len(next_tokens) == 0
248
+
249
+ def count_leaves(self, root):
250
+ next_nodes = list(root.values())
251
+ if len(next_nodes) == 0:
252
+ return 1
253
+ else:
254
+ return sum([self.count_leaves(nn) for nn in next_nodes])
255
+
256
+ def has_subsets(self, trie, nested_token_ids):
257
+ """
258
+ Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
259
+ """
260
+ leaf_count = self.count_leaves(trie)
261
+ return len(nested_token_ids) != leaf_count
262
+
263
+
264
+ class DisjunctiveConstraint(Constraint):
265
+ r"""
266
+ A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.
267
+
268
+ Args:
269
+ nested_token_ids (`list[list[int]]`):
270
+ A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
271
+ the list of words.
272
+ """
273
+
274
+ def __init__(self, nested_token_ids: list[list[int]]):
275
+ super(Constraint, self).__init__()
276
+
277
+ if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
278
+ raise ValueError(f"`nested_token_ids` has to be a non-empty list, but is {nested_token_ids}.")
279
+ if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
280
+ raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
281
+ if any(
282
+ any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
283
+ for token_ids in nested_token_ids
284
+ ):
285
+ raise ValueError(
286
+ f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
287
+ )
288
+
289
+ self.trie = DisjunctiveTrie(nested_token_ids)
290
+ self.token_ids = nested_token_ids
291
+
292
+ self.seqlen = self.trie.max_height
293
+ self.current_seq = []
294
+ self.completed = False
295
+
296
+ def advance(self):
297
+ token_list = self.trie.next_tokens(self.current_seq)
298
+
299
+ if len(token_list) == 0:
300
+ return None
301
+ else:
302
+ return token_list
303
+
304
+ def does_advance(self, token_id: int):
305
+ if not isinstance(token_id, int):
306
+ raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
307
+
308
+ next_tokens = self.trie.next_tokens(self.current_seq)
309
+
310
+ return token_id in next_tokens
311
+
312
+ def update(self, token_id: int):
313
+ if not isinstance(token_id, int):
314
+ raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
315
+
316
+ stepped = False
317
+ completed = False
318
+ reset = False
319
+
320
+ if self.does_advance(token_id):
321
+ self.current_seq.append(token_id)
322
+ stepped = True
323
+ else:
324
+ reset = True
325
+ self.reset()
326
+
327
+ completed = self.trie.reached_leaf(self.current_seq)
328
+ self.completed = completed
329
+
330
+ return stepped, completed, reset
331
+
332
+ def reset(self):
333
+ self.completed = False
334
+ self.current_seq = []
335
+
336
+ def remaining(self):
337
+ if self.completed:
338
+ # since this can be completed without reaching max height
339
+ return 0
340
+ else:
341
+ return self.seqlen - len(self.current_seq)
342
+
343
+ def copy(self, stateful=False):
344
+ new_constraint = DisjunctiveConstraint(self.token_ids)
345
+
346
+ if stateful:
347
+ new_constraint.seqlen = self.seqlen
348
+ new_constraint.current_seq = self.current_seq
349
+ new_constraint.completed = self.completed
350
+
351
+ return new_constraint
352
+
353
+
354
+ class ConstraintListState:
355
+ r"""
356
+ A class for beam scorers to track their progress through a list of constraints.
357
+
358
+ Args:
359
+ constraints (`list[Constraint]`):
360
+ A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
361
+ """
362
+
363
+ def __init__(self, constraints: list[Constraint]):
364
+ self.constraints = constraints
365
+
366
+ # max # of steps required to fulfill a given constraint
367
+ self.max_seqlen = max([c.seqlen for c in constraints])
368
+ self.n_constraints = len(constraints)
369
+ self.completed = False
370
+
371
+ self.init_state()
372
+
373
+ def init_state(self):
374
+ self.complete_constraints = []
375
+ self.inprogress_constraint = None
376
+ self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
377
+
378
+ def get_bank(self):
379
+ add = 0
380
+ if self.inprogress_constraint:
381
+ # extra points for having a constraint mid-fulfilled
382
+ add += self.max_seqlen - self.inprogress_constraint.remaining()
383
+
384
+ return (len(self.complete_constraints) * self.max_seqlen) + add
385
+
386
+ def advance(self):
387
+ """The list of tokens to generate such that we can make progress.
388
+ By "list" we don't mean the list of token that will fully fulfill a constraint.
389
+
390
+ Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
391
+ specific constraint `c_i`, we return:
392
+
393
+ `[t_k1 for k in indices of unfulfilled constraints]`
394
+
395
+ If we are in the middle of a constraint, then we return:
396
+ `[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
397
+
398
+ Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
399
+ that's the only one we'll return.
400
+ """
401
+ token_list = []
402
+ if self.inprogress_constraint is None:
403
+ for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
404
+ advance = constraint.advance()
405
+ if isinstance(advance, int):
406
+ token_list.append(advance)
407
+ elif isinstance(advance, list):
408
+ token_list.extend(advance)
409
+ else:
410
+ advance = self.inprogress_constraint.advance()
411
+ if isinstance(advance, int):
412
+ token_list.append(advance)
413
+ elif isinstance(advance, list):
414
+ token_list.extend(advance)
415
+
416
+ if len(token_list) == 0:
417
+ return None
418
+ else:
419
+ return token_list
420
+
421
+ def reset(self, token_ids: Optional[list[int]]):
422
+ """
423
+ token_ids: the tokens generated thus far to reset the state of the progress through constraints.
424
+ """
425
+ self.init_state()
426
+
427
+ if token_ids is not None:
428
+ for token in token_ids:
429
+ # completes or steps **one** constraint
430
+ complete, stepped = self.add(token)
431
+
432
+ # the entire list of constraints are fulfilled
433
+ if self.completed:
434
+ break
435
+
436
+ def add(self, token_id: int):
437
+ if not isinstance(token_id, int):
438
+ raise TypeError(f"`token_id` should be an `int`, but is `{token_id}`.")
439
+
440
+ complete, stepped = False, False
441
+
442
+ if self.completed:
443
+ complete = True
444
+ stepped = False
445
+ return complete, stepped
446
+
447
+ if self.inprogress_constraint is not None:
448
+ # In the middle of fulfilling a constraint. If the `token_id` *does* make incremental progress on the current
449
+ # job, simply update the state
450
+
451
+ stepped, complete, reset = self.inprogress_constraint.update(token_id)
452
+ if reset:
453
+ # 1. If the next token breaks the progress, then we must restart.
454
+ # e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
455
+
456
+ # But that doesn't mean we self.init_state(), since we only reset the state for this particular
457
+ # constraint, not the full list of constraints.
458
+
459
+ self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
460
+ self.inprogress_constraint = None
461
+
462
+ if complete:
463
+ # 2. If the next token completes the constraint, move it to completed list, set
464
+ # inprogress to None. If there are no pending constraints either, then this full list of constraints
465
+ # is complete.
466
+
467
+ self.complete_constraints.append(self.inprogress_constraint)
468
+ self.inprogress_constraint = None
469
+
470
+ if len(self.pending_constraints) == 0:
471
+ # we're done!
472
+ self.completed = True
473
+
474
+ else:
475
+ # Not in the middle of fulfilling a constraint. So does this `token_id` help us step towards any of our list
476
+ # of constraints?
477
+
478
+ for cidx, pending_constraint in enumerate(self.pending_constraints):
479
+ if pending_constraint.does_advance(token_id):
480
+ stepped, complete, reset = pending_constraint.update(token_id)
481
+
482
+ if not stepped:
483
+ raise Exception(
484
+ "`constraint.update(token_id)` is not yielding incremental progress, "
485
+ "even though `constraint.does_advance(token_id)` is true."
486
+ )
487
+
488
+ if complete:
489
+ self.complete_constraints.append(pending_constraint)
490
+ self.inprogress_constraint = None
491
+
492
+ if not complete and stepped:
493
+ self.inprogress_constraint = pending_constraint
494
+
495
+ if complete or stepped:
496
+ # If we made any progress at all, then it's at least not a "pending constraint".
497
+
498
+ self.pending_constraints = (
499
+ self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]
500
+ )
501
+
502
+ if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
503
+ # If there's no longer any pending after this and no inprogress either, then we must be
504
+ # complete.
505
+
506
+ self.completed = True
507
+
508
+ break # prevent accidentally stepping through multiple constraints with just one token.
509
+
510
+ return complete, stepped
511
+
512
+ def copy(self, stateful=True):
513
+ new_state = ConstraintListState(self.constraints) # we actually never touch self.constraints objects
514
+ # throughout this process. So it's at initialization state.
515
+
516
+ if stateful:
517
+ new_state.complete_constraints = [
518
+ constraint.copy(stateful=True) for constraint in self.complete_constraints
519
+ ]
520
+ if self.inprogress_constraint is not None:
521
+ new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
522
+ new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
523
+
524
+ return new_state
custom_generate/beam_search.py ADDED
@@ -0,0 +1,716 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC, abstractmethod
17
+ from collections import UserDict
18
+ from typing import Optional, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from transformers.utils import add_start_docstrings
24
+ from .beam_constraints import Constraint, ConstraintListState
25
+
26
+
27
+ PROCESS_INPUTS_DOCSTRING = r"""
28
+ Args:
29
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
30
+ Indices of input sequence tokens in the vocabulary.
31
+
32
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
33
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
34
+
35
+ [What are input IDs?](../glossary#input-ids)
36
+ next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
37
+ Current scores of the top `2 * num_beams` non-finished beam hypotheses.
38
+ next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
39
+ `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
40
+ next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
41
+ Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
42
+ pad_token_id (`int`, *optional*):
43
+ The id of the *padding* token.
44
+ eos_token_id (`Union[int, list[int]]`, *optional*):
45
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
46
+ beam_indices (`torch.LongTensor`, *optional*):
47
+ Beam indices indicating to which beam hypothesis each token correspond.
48
+ group_index (`int`, *optional*):
49
+ The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
50
+
51
+ Return:
52
+ `UserDict`: A dictionary composed of the fields as defined above:
53
+
54
+ - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
55
+ non-finished beams.
56
+ - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
57
+ to the non-finished beam_hypotheses.
58
+ - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
59
+ indicating to which beam the next tokens shall be added.
60
+
61
+ """
62
+
63
+ FINALIZE_INPUTS_DOCSTRING = r"""
64
+ Args:
65
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
66
+ Indices of input sequence tokens in the vocabulary.
67
+
68
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
69
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
70
+
71
+ [What are input IDs?](../glossary#input-ids)
72
+ final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
73
+ The final scores of all non-finished beams.
74
+ final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
75
+ The last tokens to be added to the non-finished beam_hypotheses.
76
+ final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
77
+ The beam indices indicating to which beam the `final_beam_tokens` shall be added.
78
+ pad_token_id (`int`, *optional*):
79
+ The id of the *padding* token.
80
+ eos_token_id (`Union[int, list[int]]`, *optional*):
81
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
82
+
83
+ Return:
84
+ `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
85
+ The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
86
+ due to the `eos_token_id`.
87
+
88
+ """
89
+
90
+
91
+ class BeamScorer(ABC):
92
+ """
93
+ Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
94
+ [`~PreTrainedModel.beam_sample`].
95
+ """
96
+
97
+ @abstractmethod
98
+ @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
99
+ def process(
100
+ self,
101
+ input_ids: torch.LongTensor,
102
+ next_scores: torch.FloatTensor,
103
+ next_tokens: torch.LongTensor,
104
+ next_indices: torch.LongTensor,
105
+ **kwargs,
106
+ ) -> tuple[torch.Tensor]:
107
+ raise NotImplementedError("This is an abstract method.")
108
+
109
+ @abstractmethod
110
+ @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
111
+ def finalize(
112
+ self,
113
+ input_ids: torch.LongTensor,
114
+ next_scores: torch.FloatTensor,
115
+ next_tokens: torch.LongTensor,
116
+ next_indices: torch.LongTensor,
117
+ max_length: int,
118
+ **kwargs,
119
+ ) -> torch.LongTensor:
120
+ raise NotImplementedError("This is an abstract method.")
121
+
122
+ class ConstrainedBeamSearchScorer(BeamScorer):
123
+ r"""
124
+ [`BeamScorer`] implementing constrained beam search decoding.
125
+
126
+
127
+ Args:
128
+ batch_size (`int`):
129
+ Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
130
+ num_beams (`int`):
131
+ Number of beams for beam search.
132
+ constraints (`list[Constraint]`):
133
+ A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
134
+ output. For more information, the documentation of [`Constraint`] should be read.
135
+ device (`torch.device`):
136
+ Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
137
+ allocated.
138
+ length_penalty (`float`, *optional*, defaults to 1.0):
139
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
140
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
141
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
142
+ `length_penalty` < 0.0 encourages shorter sequences.
143
+ do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
144
+ Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
145
+ `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
146
+ heuristic is applied and the generation stops when is it very unlikely to find better candidates;
147
+ `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
148
+ beam search algorithm).
149
+ num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
150
+ The number of beam hypotheses that shall be returned upon calling
151
+ [`~transformers.BeamSearchScorer.finalize`].
152
+ num_beam_groups (`int`, *optional*, defaults to 1):
153
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
154
+ See [this paper](https://huggingface.co/papers/1610.02424) for more details.
155
+ max_length (`int`, *optional*):
156
+ The maximum length of the sequence to be generated.
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ batch_size: int,
162
+ num_beams: int,
163
+ constraints: list[Constraint],
164
+ device: torch.device,
165
+ length_penalty: Optional[float] = 1.0,
166
+ do_early_stopping: Optional[Union[bool, str]] = False,
167
+ num_beam_hyps_to_keep: Optional[int] = 1,
168
+ num_beam_groups: Optional[int] = 1,
169
+ max_length: Optional[int] = None,
170
+ ):
171
+ self.num_beams = num_beams
172
+ self.device = device
173
+ self.length_penalty = length_penalty
174
+ self.do_early_stopping = do_early_stopping
175
+ self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
176
+ self.num_beam_groups = num_beam_groups
177
+ self.group_size = self.num_beams // self.num_beam_groups
178
+ self.constraints = constraints
179
+
180
+ self._is_init = False
181
+ self._beam_hyps = [
182
+ BeamHypotheses(
183
+ num_beams=self.num_beams,
184
+ length_penalty=self.length_penalty,
185
+ early_stopping=self.do_early_stopping,
186
+ max_length=max_length,
187
+ )
188
+ for _ in range(batch_size)
189
+ ]
190
+ self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
191
+
192
+ if not isinstance(num_beams, int) or num_beams <= 1:
193
+ raise ValueError(
194
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
195
+ " one should make use of `greedy_search` instead."
196
+ )
197
+
198
+ if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
199
+ raise ValueError(
200
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
201
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
202
+ )
203
+
204
+ @property
205
+ def is_done(self) -> bool:
206
+ return self._done.all()
207
+
208
+ def make_constraint_states(self, n):
209
+ return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
210
+
211
+ def check_completes_constraints(self, sequence):
212
+ new_state = self.make_constraint_states(1)[0]
213
+ new_state.reset(sequence)
214
+ return new_state.completed
215
+
216
+ def process(
217
+ self,
218
+ input_ids: torch.LongTensor,
219
+ next_scores: torch.FloatTensor,
220
+ next_tokens: torch.LongTensor,
221
+ next_indices: torch.LongTensor,
222
+ scores_for_all_vocab: torch.FloatTensor,
223
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
224
+ eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
225
+ beam_indices: Optional[torch.LongTensor] = None,
226
+ decoder_prompt_len: Optional[int] = 0,
227
+ ) -> tuple[torch.Tensor]:
228
+ r"""
229
+ Args:
230
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
231
+ Indices of input sequence tokens in the vocabulary.
232
+
233
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
234
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
235
+
236
+ [What are input IDs?](../glossary#input-ids)
237
+ next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
238
+ Current scores of the top `2 * num_beams` non-finished beam hypotheses.
239
+ next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
240
+ `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
241
+ next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
242
+ Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
243
+ scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, vocab_size)`):
244
+ The scores of all tokens in the vocabulary for each of the beam hypotheses.
245
+ pad_token_id (`int`, *optional*):
246
+ The id of the *padding* token.
247
+ eos_token_id (`Union[int, list[int]]`, *optional*):
248
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
249
+ beam_indices (`torch.LongTensor`, *optional*):
250
+ Beam indices indicating to which beam hypothesis each token correspond.
251
+ decoder_prompt_len (`int`, *optional*):
252
+ The length of prompt that is included in the input to decoder.
253
+ Return:
254
+ `UserDict`: A dictionary composed of the fields as defined above:
255
+
256
+ - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
257
+ all
258
+ non-finished beams.
259
+
260
+ - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
261
+ added
262
+ to the non-finished beam_hypotheses.
263
+ - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
264
+ indicating to which beam the next tokens shall be added.
265
+ """
266
+
267
+ # add up to the length which the next_scores is calculated on (including decoder prompt)
268
+ cur_len = input_ids.shape[-1] + 1
269
+ batch_size = len(self._beam_hyps)
270
+ if batch_size != (input_ids.shape[0] // self.group_size):
271
+ if self.num_beam_groups > 1:
272
+ raise ValueError(
273
+ f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
274
+ f"size of {self.group_size} is expected by the beam scorer."
275
+ )
276
+ else:
277
+ raise ValueError(
278
+ f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
279
+ f"{self.group_size} is expected by the beam scorer."
280
+ )
281
+
282
+ device = input_ids.device
283
+
284
+ next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
285
+ next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
286
+ next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
287
+
288
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
289
+ if isinstance(eos_token_id, int):
290
+ eos_token_id = [eos_token_id]
291
+ eos_token_id = torch.tensor(eos_token_id)
292
+
293
+ for batch_idx, beam_hyp in enumerate(self._beam_hyps):
294
+ if self._done[batch_idx]:
295
+ if self.num_beams < len(beam_hyp):
296
+ raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
297
+ if eos_token_id is None or pad_token_id is None:
298
+ raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
299
+ # pad the batch
300
+ next_beam_scores[batch_idx, :] = 0
301
+ next_beam_tokens[batch_idx, :] = pad_token_id
302
+ next_beam_indices[batch_idx, :] = 0
303
+ continue
304
+
305
+ # next tokens for this sentence.
306
+ beam_idx = 0
307
+ for beam_token_rank, (next_token, next_score, next_index) in enumerate(
308
+ zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
309
+ ):
310
+ batch_beam_idx = batch_idx * self.group_size + next_index
311
+ # add to generated hypotheses if end of sentence
312
+ if (eos_token_id is not None) and (next_token.item() in eos_token_id):
313
+ # if beam_token does not belong to top num_beams tokens, it should not be added
314
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
315
+ if is_beam_token_worse_than_top_num_beams:
316
+ continue
317
+
318
+ completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].tolist())
319
+ if completes_constraint:
320
+ if beam_indices is not None:
321
+ beam_index = beam_indices[batch_beam_idx]
322
+ beam_index = beam_index + (batch_beam_idx,)
323
+ else:
324
+ beam_index = None
325
+
326
+ beam_hyp.add(
327
+ input_ids[batch_beam_idx].clone(),
328
+ next_score.item(),
329
+ beam_indices=beam_index,
330
+ generated_len=cur_len - decoder_prompt_len,
331
+ )
332
+ else:
333
+ # add next predicted token since it is not eos_token
334
+ next_beam_scores[batch_idx, beam_idx] = next_score
335
+ next_beam_tokens[batch_idx, beam_idx] = next_token
336
+ next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
337
+ beam_idx += 1
338
+
339
+ # once the beam for next step is full, don't add more tokens to it.
340
+ if beam_idx == self.group_size:
341
+ break
342
+
343
+ new_scores, new_tokens, new_indices = self.step_sentence_constraint(
344
+ batch_idx,
345
+ input_ids,
346
+ scores_for_all_vocab,
347
+ next_beam_scores[batch_idx],
348
+ next_beam_tokens[batch_idx],
349
+ next_beam_indices[batch_idx],
350
+ )
351
+
352
+ next_beam_scores[batch_idx] = new_scores
353
+ next_beam_tokens[batch_idx] = new_tokens
354
+ next_beam_indices[batch_idx] = new_indices
355
+
356
+ if beam_idx < self.group_size:
357
+ raise ValueError(
358
+ f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
359
+ f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
360
+ )
361
+
362
+ # Check if we are done so that we can save a pad step if all(done)
363
+ self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
364
+ next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
365
+ )
366
+
367
+ return UserDict(
368
+ {
369
+ "next_beam_scores": next_beam_scores.view(-1),
370
+ "next_beam_tokens": next_beam_tokens.view(-1),
371
+ "next_beam_indices": next_beam_indices.view(-1),
372
+ }
373
+ )
374
+
375
+ def step_sentence_constraint(
376
+ self,
377
+ batch_idx: int,
378
+ input_ids: torch.LongTensor,
379
+ vocab_scores: torch.FloatTensor,
380
+ sent_beam_scores: torch.FloatTensor,
381
+ sent_beam_tokens: torch.LongTensor,
382
+ sent_beam_indices: torch.LongTensor,
383
+ push_progress: bool = False,
384
+ ):
385
+ # sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
386
+ # (candidate next tokens)
387
+
388
+ # 1. Adding "advance_tokens"
389
+ # using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
390
+ # advance us in fulfilling the constraints.
391
+
392
+ # 2. Selecting best candidates such that we end up with highest probable candidates
393
+ # that fulfill our constraints.
394
+
395
+ orig_len = sent_beam_indices.size(0)
396
+ device = sent_beam_indices.device
397
+
398
+ # initialize states
399
+ topk_contraint_states = self.make_constraint_states(orig_len)
400
+ advance_constraint_states = self.make_constraint_states(orig_len)
401
+
402
+ sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
403
+ this_batch_input_ids = input_ids[sidx:eidx]
404
+ this_batch_token_scores = vocab_scores[sidx:eidx]
405
+ full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
406
+
407
+ # need to make new hypothesis that advance the constraints
408
+ track_new = {
409
+ "new_seqs": full_hypotheses.tolist(),
410
+ "new_states": [],
411
+ "new_indices": [],
412
+ "new_tokens": [],
413
+ "new_scores": [],
414
+ }
415
+ for seq_idx, pre_seq in enumerate(this_batch_input_ids):
416
+ # pre_seq = ith sequence generated before this step.
417
+
418
+ # input_ids -> (topk) generic beam search best model next tokens
419
+ # -> (advance) constraints forcing the next token
420
+ # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
421
+ # hypotheses.
422
+
423
+ topk_state = topk_contraint_states[seq_idx]
424
+ topk_state.reset(full_hypotheses[seq_idx].tolist())
425
+
426
+ advance_state = advance_constraint_states[seq_idx]
427
+ advance_state.reset(pre_seq.tolist())
428
+
429
+ if not advance_state.completed:
430
+ advance_tokens = torch.tensor(advance_state.advance(), dtype=torch.long, device=device)
431
+ for advance_token in advance_tokens:
432
+ # since adding each `advance_token` leads to a different hypothesis, create new state instance.
433
+ new_state = advance_state.copy(stateful=True)
434
+ new_state.add(advance_token.tolist())
435
+
436
+ advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).tolist()
437
+ if advance_seq not in track_new["new_seqs"]:
438
+ # prevent duplicates, which are basically bound to happen in this process.
439
+ track_new["new_seqs"].append(advance_seq)
440
+ track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
441
+ track_new["new_tokens"].append(advance_token)
442
+ track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
443
+ track_new["new_states"].append(new_state)
444
+ elif push_progress:
445
+ # Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
446
+ # actually fulfill our constraints. For example, let constraints == ["loves pies"] and
447
+
448
+ # pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
449
+
450
+ # Without this step, if `sent_beam_indices` is something like [1,1], then
451
+ # 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
452
+ # 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
453
+ # the else part of `if constraints_completed[seq_idx]`)
454
+ # 3. it ends up simply getting removed from consideration.
455
+
456
+ # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
457
+ # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
458
+ # search times, since completed sequences keep getting removed after all this effort for constrained
459
+ # generation.
460
+
461
+ # Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
462
+ # appending the next likely token in the vocabulary and adding it to the list of hypotheses.
463
+
464
+ new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
465
+ advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
466
+
467
+ advance_state = advance_constraint_states[seq_idx]
468
+
469
+ advance_seq = advance_seq.tolist()
470
+
471
+ advance_state.reset(advance_seq)
472
+ if advance_seq not in track_new["new_seqs"]:
473
+ # but still don't want to have duplicates
474
+ track_new["new_seqs"].append(advance_seq)
475
+ track_new["new_indices"].append(seq_idx)
476
+ track_new["new_tokens"].append(new_token)
477
+ track_new["new_scores"].append(new_score)
478
+ track_new["new_states"].append(advance_state)
479
+
480
+ if len(track_new["new_indices"]) > 0:
481
+ new_indices = torch.tensor(track_new["new_indices"], device=device)
482
+ new_tokens = torch.stack(track_new["new_tokens"]).to(device)
483
+ new_scores = torch.stack(track_new["new_scores"]).to(device)
484
+
485
+ all_states = topk_contraint_states + track_new["new_states"]
486
+ all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
487
+ all_scores = torch.cat((sent_beam_scores, new_scores), -1)
488
+ all_banks = torch.tensor([one.get_bank() for one in all_states], device=device)
489
+
490
+ zipped = all_banks * 100 + all_scores
491
+ indices = zipped.sort(descending=True).indices
492
+ sorted_banks = all_banks[indices]
493
+
494
+ # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
495
+
496
+ counter = -1
497
+ cur_bank = sorted_banks[0]
498
+ increments = []
499
+ for bank in sorted_banks:
500
+ if bank == cur_bank:
501
+ counter += 1
502
+ else:
503
+ counter = 0
504
+ cur_bank = bank
505
+ increments.append(counter)
506
+ rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
507
+
508
+ indices = indices[rearrangers][:orig_len]
509
+
510
+ sent_beam_scores = all_scores[indices]
511
+ sent_beam_tokens = all_tokens[indices]
512
+ sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
513
+
514
+ return sent_beam_scores, sent_beam_tokens, sent_beam_indices
515
+
516
+ def finalize(
517
+ self,
518
+ input_ids: torch.LongTensor,
519
+ final_beam_scores: torch.FloatTensor,
520
+ final_beam_tokens: torch.LongTensor,
521
+ final_beam_indices: torch.LongTensor,
522
+ max_length: int,
523
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
524
+ eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
525
+ beam_indices: Optional[torch.LongTensor] = None,
526
+ decoder_prompt_len: Optional[int] = 0,
527
+ ) -> tuple[torch.LongTensor]:
528
+ batch_size = len(self._beam_hyps)
529
+
530
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
531
+ if isinstance(eos_token_id, int):
532
+ eos_token_id = [eos_token_id]
533
+ eos_token_id = torch.tensor(eos_token_id)
534
+
535
+ # finalize all open beam hypotheses and add to generated hypotheses
536
+ for batch_idx, beam_hyp in enumerate(self._beam_hyps):
537
+ if self._done[batch_idx]:
538
+ continue
539
+
540
+ # all open beam hypotheses are added to the beam hypothesis
541
+ # beam hypothesis class automatically keeps the best beams
542
+
543
+ ids_collect = []
544
+ for beam_id in range(self.num_beams):
545
+ batch_beam_idx = batch_idx * self.num_beams + beam_id
546
+ final_score = final_beam_scores[batch_beam_idx].item()
547
+ final_tokens = input_ids[batch_beam_idx]
548
+
549
+ completes_constraint = self.check_completes_constraints(final_tokens.tolist())
550
+ if completes_constraint:
551
+ beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
552
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
553
+ beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
554
+ ids_collect.append(beam_id)
555
+
556
+ # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
557
+ # generation. In these cases we simply return the highest scoring outputs.
558
+ if len(ids_collect) < self.num_beam_hyps_to_keep:
559
+ for beam_id in range(self.num_beams):
560
+ if beam_id not in ids_collect:
561
+ batch_beam_idx = batch_idx * self.num_beams + beam_id
562
+ final_score = final_beam_scores[batch_beam_idx].item()
563
+ final_tokens = input_ids[batch_beam_idx]
564
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
565
+ beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
566
+ if len(ids_collect) >= self.num_beam_hyps_to_keep:
567
+ break
568
+
569
+ # select the best hypotheses
570
+ sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
571
+ best = []
572
+ best_indices = []
573
+ best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
574
+
575
+ # retrieve best hypotheses
576
+ for i, beam_hyp in enumerate(self._beam_hyps):
577
+ sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
578
+ for j in range(self.num_beam_hyps_to_keep):
579
+ best_hyp_tuple = sorted_hyps.pop()
580
+ best_score = best_hyp_tuple[0]
581
+ best_hyp = best_hyp_tuple[1]
582
+ best_index = best_hyp_tuple[2]
583
+ sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
584
+
585
+ # append to lists
586
+ best.append(best_hyp)
587
+
588
+ # append indices to list
589
+ best_indices.append(best_index)
590
+
591
+ best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
592
+
593
+ # prepare for adding eos
594
+ sent_lengths_max = sent_lengths.max().item() + 1
595
+
596
+ sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
597
+ decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
598
+
599
+ if len(best_indices) > 0 and best_indices[0] is not None:
600
+ indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
601
+ else:
602
+ indices = None
603
+
604
+ # shorter batches are padded if needed
605
+ if sent_lengths.min().item() != sent_lengths.max().item():
606
+ if pad_token_id is None:
607
+ raise ValueError("`pad_token_id` has to be defined")
608
+ decoded.fill_(pad_token_id)
609
+
610
+ if indices is not None:
611
+ indices.fill_(-1)
612
+
613
+ # fill with hypotheses and eos_token_id if the latter fits in
614
+ for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
615
+ decoded[i, : sent_lengths[i]] = hypo
616
+
617
+ if indices is not None:
618
+ indices[i, : len(best_idx)] = torch.tensor(best_idx)
619
+
620
+ if sent_lengths[i] < sent_max_len:
621
+ # inserting only the first eos_token_id
622
+ decoded[i, sent_lengths[i]] = eos_token_id[0]
623
+
624
+ return UserDict(
625
+ {
626
+ "sequences": decoded,
627
+ "sequence_scores": best_scores,
628
+ "beam_indices": indices,
629
+ }
630
+ )
+
+
+ class BeamHypotheses:
+     def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
+         """
+         Initialize n-best list of hypotheses.
+         """
+         self.length_penalty = length_penalty
+         self.early_stopping = early_stopping
+         self.max_length = max_length
+         self.num_beams = num_beams
+         self.beams = []
+         self.worst_score = 1e9
+
+         if not isinstance(self.early_stopping, bool) and self.max_length is None:
+             raise ValueError(
+                 "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
+                 " BeamScorer class instance at initialization time."
+             )
+
+     def __len__(self):
+         """
+         Number of hypotheses in the list.
+         """
+         return len(self.beams)
+
+     def add(
+         self,
+         hyp: torch.LongTensor,
+         sum_logprobs: float,
+         beam_indices: Optional[torch.LongTensor] = None,
+         generated_len: Optional[int] = None,
+     ):
+         """
+         Add a new hypothesis to the list.
+         """
+         if generated_len is not None:
+             score = sum_logprobs / (generated_len**self.length_penalty)
+         # This 'else' case exists for retrocompatibility
+         else:
+             score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
+
+         if len(self) < self.num_beams or score > self.worst_score:
+             self.beams.append((score, hyp, beam_indices))
+             if len(self) > self.num_beams:
+                 sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
+                 del self.beams[sorted_next_scores[0][1]]
+                 self.worst_score = sorted_next_scores[1][0]
+             else:
+                 self.worst_score = min(score, self.worst_score)
+
+     def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
+         """
+         If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
+         one in the heap, then we are done with this sentence.
+         """
+
+         if len(self) < self.num_beams:
+             return False
+
+         # `True`: stop as soon as at least `num_beams` hypotheses are finished
+         if self.early_stopping is True:
+             return True
+         # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
+         # when `length_penalty` is positive. See the discussion below for more details.
+         # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+         elif self.early_stopping is False:
+             highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
+             ret = self.worst_score >= highest_attainable_score
+             return ret
+         # `"never"`: compute the best possible score, depending on the sign of `length_penalty`
+         else:
+             # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
+             # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
+             # its max this way
+             if self.length_penalty > 0.0:
+                 if self.max_length <= decoder_prompt_len:
+                     raise ValueError("max_length is not larger than decoder prompt length")
+                 highest_attainable_score = (
+                     best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+                 )
+             # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
+             else:
+                 highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
+             ret = self.worst_score >= highest_attainable_score
+             return ret
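
The score used by `BeamHypotheses.add` above is `sum_logprobs / generated_len ** length_penalty`. A minimal, self-contained sketch (with illustrative numbers, not model outputs) of how the penalty trades off short against long hypotheses:

```python
# Illustrative only: the length-penalized score used when adding beam hypotheses.
def beam_score(sum_logprobs: float, generated_len: int, length_penalty: float) -> float:
    return sum_logprobs / (generated_len ** length_penalty)

short_hyp = (-3.0, 4)   # (sum of log-probs, generated length); same per-token quality as the long one
long_hyp = (-9.0, 12)

for lp in (0.5, 1.0, 2.0):
    print(f"length_penalty={lp}: short={beam_score(*short_hyp, lp):.3f}  long={beam_score(*long_hyp, lp):.3f}")
# length_penalty < 1.0 favors the short hypothesis, > 1.0 the long one; at 1.0 they tie here.
```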
custom_generate/generate.py ADDED
@@ -0,0 +1,337 @@
+ from typing import Union
+ import torch
+ from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
+ from transformers.generation.utils import (
+     GenerateBeamOutput,
+     GenerationMixin,
+     GenerateBeamDecoderOnlyOutput,
+     GenerateBeamEncoderDecoderOutput,
+ )
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import logging
+
+ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
+ from .beam_search import ConstrainedBeamSearchScorer
+
+ logger = logging.getLogger(__name__)
+
+
+ def _constrained_beam_search(
+     model,
+     input_ids: torch.LongTensor,
+     logits_processor: LogitsProcessorList,
+     stopping_criteria: StoppingCriteriaList,
+     generation_config: GenerationConfig,
+     synced_gpus: bool,
+     **model_kwargs,
+ ) -> Union[GenerateBeamOutput, torch.LongTensor]:
+     r"""
+     Generates sequences of token ids for models with a language modeling head using **constrained beam search
+     decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+     Parameters:
+         input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
+             The sequence used as a prompt for the generation.
+         logits_processor (`LogitsProcessorList`):
+             An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+             used to modify the prediction scores of the language modeling head applied at each generation step.
+         stopping_criteria (`StoppingCriteriaList`):
+             An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+             used to tell if the generation loop should stop.
+         generation_config ([`~generation.GenerationConfig`]):
+             The generation configuration to be used as parametrization of the decoding method.
+         synced_gpus (`bool`):
+             Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+             `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+         model_kwargs:
+             Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+             an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+     Return:
+         [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
+         `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+         [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+         `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
+         `model.config.is_encoder_decoder=True`.
+     """
+     if generation_config.constraints is not None or generation_config.force_words_ids is not None:
+         constrained_wrong_parameter_msg = (
+             "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
+             "However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
+             "mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
+         )
+         if generation_config.do_sample is True:
+             raise ValueError(
+                 constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=generation_config.do_sample)
+             )
+
+     final_constraints = []
+     if generation_config.constraints is not None:
+         final_constraints = generation_config.constraints
+
+     if generation_config.force_words_ids is not None:
+
+         def typeerror():
+             raise ValueError(
+                 "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
+                 f"of positive integers, but is {generation_config.force_words_ids}."
+             )
+
+         if (
+             not isinstance(generation_config.force_words_ids, list)
+             or len(generation_config.force_words_ids) == 0
+         ):
+             typeerror()
+
+         for word_ids in generation_config.force_words_ids:
+             if isinstance(word_ids[0], list):
+                 if not isinstance(word_ids, list) or len(word_ids) == 0:
+                     typeerror()
+                 if any(not isinstance(token_ids, list) for token_ids in word_ids):
+                     typeerror()
+                 if any(
+                     any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
+                     for token_ids in word_ids
+                 ):
+                     typeerror()
+
+                 constraint = DisjunctiveConstraint(word_ids)
+             else:
+                 if not isinstance(word_ids, list) or len(word_ids) == 0:
+                     typeerror()
+                 if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
+                     typeerror()
+
+                 constraint = PhrasalConstraint(word_ids)
+             final_constraints.append(constraint)
+
+     # `input_ids` is documented above as already expanded to (batch_size * num_beams, sequence_length)
+     batch_size = input_ids.shape[0] // generation_config.num_beams
+
+     # define beam scorer
+     constrained_beam_scorer = ConstrainedBeamSearchScorer(
+         constraints=final_constraints,
+         batch_size=batch_size,
+         num_beams=generation_config.num_beams,
+         device=input_ids.device,
+         length_penalty=generation_config.length_penalty,
+         do_early_stopping=generation_config.early_stopping,
+         num_beam_hyps_to_keep=generation_config.num_return_sequences,
+         max_length=generation_config.max_length,
+     )
+     # init values
+     pad_token_id = generation_config._pad_token_tensor
+     eos_token_id = generation_config._eos_token_tensor
+     output_attentions = generation_config.output_attentions
+     output_hidden_states = generation_config.output_hidden_states
+     output_scores = generation_config.output_scores
+     output_logits = generation_config.output_logits
+     return_dict_in_generate = generation_config.return_dict_in_generate
+
+     batch_size = len(constrained_beam_scorer._beam_hyps)
+     num_beams = constrained_beam_scorer.num_beams
+
+     batch_beam_size, cur_len = input_ids.shape[:2]
+     model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+
+     if num_beams * batch_size != batch_beam_size:
+         raise ValueError(
+             f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+         )
+
+     # init attention / hidden states / scores tuples
+     scores = () if (return_dict_in_generate and output_scores) else None
+     raw_logits = () if (return_dict_in_generate and output_logits) else None
+     beam_indices = (
+         tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
+     )
+     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+     decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+     # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+     if return_dict_in_generate and model.config.is_encoder_decoder:
+         encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+         encoder_hidden_states = (
+             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+         )
+
+     # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+     # of the first beam are considered to avoid sampling the exact same tokens across all beams.
+     beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+     beam_scores[:, 1:] = -1e9
+     beam_scores = beam_scores.view((batch_size * num_beams,))
+
+     this_peer_finished = False
+
+     decoder_prompt_len = input_ids.shape[1]  # record the prompt length of decoder
+     while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+         # prepare variable output controls (note: some models won't accept all output controls)
+         model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+         model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+         outputs = model(**model_inputs, return_dict=True)
+
+         # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+         model_kwargs = model._update_model_kwargs_for_generation(
+             outputs,
+             model_kwargs,
+             is_encoder_decoder=model.config.is_encoder_decoder,
+         )
+         if synced_gpus and this_peer_finished:
+             cur_len = cur_len + 1
+             continue
+
+         # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+         # (the clone itself is always small)
+         # .float() is needed to retain precision for later logits manipulations
+         next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
+         next_token_scores = nn.functional.log_softmax(
+             next_token_logits, dim=-1
+         )  # (batch_size * num_beams, vocab_size)
+
+         next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+
+         next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+             next_token_scores_processed
+         )
+
+         scores_for_all_vocab = next_token_scores.clone()
+
+         # Store scores, attentions and hidden_states when required
+         if return_dict_in_generate:
+             if output_scores:
+                 scores += (next_token_scores,)
+             if output_logits:
+                 raw_logits += (next_token_logits,)
+             if output_attentions:
+                 decoder_attentions += (
+                     (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
+                 )
+                 if model.config.is_encoder_decoder:
+                     cross_attentions += (outputs.cross_attentions,)
+
+             if output_hidden_states:
+                 decoder_hidden_states += (
+                     (outputs.decoder_hidden_states,)
+                     if model.config.is_encoder_decoder
+                     else (outputs.hidden_states,)
+                 )
+
+         # reshape for beam search
+         vocab_size = next_token_scores.shape[-1]
+         next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
+
+         # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+         n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
+         next_token_scores, next_tokens = torch.topk(
+             next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
+         )
+
+         next_indices = (next_tokens / vocab_size).long()
+         next_tokens = next_tokens % vocab_size
+
+         # stateless
+         beam_outputs = constrained_beam_scorer.process(
+             input_ids,
+             next_token_scores,
+             next_tokens,
+             next_indices,
+             scores_for_all_vocab,
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             beam_indices=beam_indices,
+             decoder_prompt_len=decoder_prompt_len,
+         )
+         beam_scores = beam_outputs["next_beam_scores"]
+         beam_next_tokens = beam_outputs["next_beam_tokens"]
+         beam_idx = beam_outputs["next_beam_indices"]
+
+         input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+
+         # This is needed to properly delete outputs.logits which may be very large for first iteration
+         # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+         # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
+         # (that way the memory peak does not include outputs.logits)
+         del outputs
+
+         # NOTE: we need to check if `model._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
+         if model_kwargs.get("past_key_values", None) is not None:
+             if hasattr(model, "_reorder_cache"):
+                 model_kwargs["past_key_values"] = model._reorder_cache(model_kwargs["past_key_values"], beam_idx)
+             else:
+                 model_kwargs["past_key_values"].reorder_cache(beam_idx)
+
+         if return_dict_in_generate and output_scores:
+             beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))
+
+         # increase cur_len
+         cur_len = cur_len + 1
+
+         if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
+             this_peer_finished = True
+
+     sequence_outputs = constrained_beam_scorer.finalize(
+         input_ids,
+         beam_scores,
+         next_tokens,
+         next_indices,
+         pad_token_id=pad_token_id,
+         eos_token_id=eos_token_id,
+         max_length=stopping_criteria.max_length,
+         beam_indices=beam_indices,
+         decoder_prompt_len=decoder_prompt_len,
+     )
+
+     if return_dict_in_generate:
+         if not output_scores:
+             sequence_outputs["sequence_scores"] = None
+         if model.config.is_encoder_decoder:
+             return GenerateBeamEncoderDecoderOutput(
+                 sequences=sequence_outputs["sequences"],
+                 sequences_scores=sequence_outputs["sequence_scores"],
+                 scores=scores,
+                 logits=raw_logits,
+                 beam_indices=sequence_outputs["beam_indices"],
+                 encoder_attentions=encoder_attentions,
+                 encoder_hidden_states=encoder_hidden_states,
+                 decoder_attentions=decoder_attentions,
+                 cross_attentions=cross_attentions,
+                 decoder_hidden_states=decoder_hidden_states,
+                 past_key_values=model_kwargs.get("past_key_values"),
+             )
+         else:
+             return GenerateBeamDecoderOnlyOutput(
+                 sequences=sequence_outputs["sequences"],
+                 sequences_scores=sequence_outputs["sequence_scores"],
+                 scores=scores,
+                 logits=raw_logits,
+                 beam_indices=sequence_outputs["beam_indices"],
+                 attentions=decoder_attentions,
+                 hidden_states=decoder_hidden_states,
+                 past_key_values=model_kwargs.get("past_key_values"),
+             )
+     else:
+         return sequence_outputs["sequences"]
+
+
+ def generate(model, *args, **kwargs):
+     """Custom generate function for constrained beam search decoding.
+
+     Args:
+         model (`PreTrainedModel`):
+             The model to generate from.
+         num_beams (`int`): The number of beams to use for beam search.
+         constraints (`list[Constraint]`, *optional*):
+             Custom constraints that can be added to the generation to ensure that the output will contain the use
+             of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
+         force_words_ids (`list[list[int]]` or `list[list[list[int]]]`, *optional*):
+             List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple
+             list of words that must be included, the opposite of `bad_words_ids`. If given `list[list[list[int]]]`,
+             this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
+             where one can allow different forms of each word.
+         length_penalty (`float`): The length penalty to use for beam search.
+         early_stopping (`bool`): Whether to stop beam search when sufficient beams have finished.
+         num_return_sequences (`int`): The number of sequences to return.
+         max_length (`int`): The maximum length of the generated sequence.
+     """
+     generation_outputs = GenerationMixin.generate(
+         model, *args, custom_generate=_constrained_beam_search, **kwargs
+     )
+     return generation_outputs
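
A hedged usage sketch for the entry point above. The repo id passed to `custom_generate` is a placeholder for this repository's Hub id, the prompt and constraint words are illustrative, and loading a custom decoding loop from the Hub requires `trust_remote_code=True`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The weather today is", return_tensors="pt")

# constraint phrases are tokenized without special tokens
force_words_ids = tokenizer(["sunny"], add_special_tokens=False).input_ids

outputs = model.generate(
    **inputs,
    custom_generate="<this-repo-id>",  # placeholder: the Hub id of this repository
    trust_remote_code=True,
    num_beams=4,
    do_sample=False,  # the loop above raises if sampling is enabled
    force_words_ids=force_words_ids,
    max_new_tokens=30,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```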
generation_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "temperature": 0.6,
+   "top_k": 20,
+   "top_p": 0.95,
+   "transformers_version": "4.56.0"
+ }
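
These defaults enable sampling (`do_sample: true` with temperature/top-k/top-p), while the constrained loop in `custom_generate/generate.py` raises if `do_sample` is true. A minimal sketch of overriding the shipped defaults at call time, assuming the base model id:

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("Qwen/Qwen3-0.6B")
gen_cfg.do_sample = False        # required: constrained beam search is incompatible with sampling
gen_cfg.num_beams = 4
gen_cfg.num_return_sequences = 1
print(gen_cfg.do_sample, gen_cfg.num_beams)  # False 4
```

The overridden config can then be passed along as `generate(..., generation_config=gen_cfg)`.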
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f47f71177f32bcd101b7573ec9171e6a57f4f4d31148d38e382306f42996874b
+ size 1503300328
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151666": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151667": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151668": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "chat_template": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is string %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
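
A small sketch of preparing constraint ids with this tokenizer (`Qwen2Tokenizer`, eos `<|im_end|>`, pad `<|endoftext|>` per the config above). Phrases are encoded with `add_special_tokens=False` so only the phrase itself is forced; the example words are illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# list[list[int]]: every phrase must appear somewhere in the output
phrase_ids = tokenizer(["very fast"], add_special_tokens=False).input_ids

# list[list[list[int]]]: any one surface form per group satisfies the constraint
disjunctive_ids = [tokenizer(["rain", "raining"], add_special_tokens=False).input_ids]

print(phrase_ids)
print(disjunctive_ids)
print(tokenizer.eos_token, tokenizer.pad_token)  # <|im_end|> <|endoftext|>
```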
vocab.json ADDED
The diff for this file is too large to render. See raw diff