Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .ruff_cache/.gitignore +2 -0
- .ruff_cache/0.12.8/7010951691598163845 +0 -0
- .ruff_cache/CACHEDIR.TAG +1 -0
- README.md +118 -0
- config.json +30 -0
- custom_generate/beam_constraints.py +524 -0
- custom_generate/beam_search.py +716 -0
- custom_generate/generate.py +337 -0
- generation_config.json +13 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- tokenizer.json +3 -0
- tokenizer_config.json +239 -0
- vocab.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
.ruff_cache/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Automatically created by ruff.
|
2 |
+
*
|
.ruff_cache/0.12.8/7010951691598163845
ADDED
Binary file (149 Bytes). View file
|
|
.ruff_cache/CACHEDIR.TAG
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
README.md
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
|
3 |
+
library_name: transformers
|
4 |
+
tags:
|
5 |
+
- custom_generate
|
6 |
+
---
|
7 |
+
|
8 |
+
## Description
|
9 |
+
|
10 |
+
Constrained Beam Search extends standard beam search by allowing you to enforce lexical or phrasal constraints in the generated output. This is useful when you know certain words or phrases must appear (e.g., translation dictionaries, product names, slot values), or when multiple outputs are equally probable but only some are desirable for your use case.
|
11 |
+
|
12 |
+
Unlike ordinary beam search, constrained beam search steers generation to include required subsequences somewhere in the final output while balancing fluency.
|
13 |
+
|
14 |
+
---
|
15 |
+
|
16 |
+
## Why it's difficult
|
17 |
+
|
18 |
+
Beam search generates token-by-token and scores candidates locally. Forcing a phrase like "is fast" to appear somewhere requires the search to plan several steps ahead and decide when to insert the constrained tokens without breaking fluency. The problem becomes more complex with multiple constraints, optional alternatives, or ordering requirements.
|
19 |
+
|
20 |
+
Constrained beam search solves this by:
|
21 |
+
- Injecting constraint-progressing tokens among regular high-probability candidates
|
22 |
+
- Grouping beams into banks by how much of the constraints they satisfied
|
23 |
+
- Selecting beams round-robin across banks to balance fluency and constraint satisfaction
|
24 |
+
|
25 |
+
---
|
26 |
+
|
27 |
+
## Base model
|
28 |
+
|
29 |
+
* [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B)
|
30 |
+
|
31 |
+
---
|
32 |
+
|
33 |
+
## Model compatibility
|
34 |
+
|
35 |
+
- Encoder-decoder and decoder-only transformer models
|
36 |
+
|
37 |
+
---
|
38 |
+
|
39 |
+
## Additional Arguments
|
40 |
+
|
41 |
+
- `constraints` (list[Constraint]): Advanced constraints, e.g., `PhrasalConstraint`, `DisjunctiveConstraint`
|
42 |
+
- `force_words_ids` (list[list[int]] | list[list[list[int]]]): Simple way to specify words/phrases or disjunctive sets
|
43 |
+
- `num_beams` (int): Beam width
|
44 |
+
- Other standard beam args: `length_penalty`, `early_stopping`, `num_return_sequences`, `max_length`
|
45 |
+
|
46 |
+
Notes:
|
47 |
+
- Constrained decoding is incompatible with sampling: set `do_sample=False`
|
48 |
+
- Tokenize constraints without adding special tokens
|
49 |
+
|
50 |
+
---
|
51 |
+
|
52 |
+
## Example 1: Forcing a word (formal German translation)
|
53 |
+
|
54 |
+
```python
|
55 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
56 |
+
|
57 |
+
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
58 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
59 |
+
|
60 |
+
encoder_input_str = "translate English to German: How old are you?"
|
61 |
+
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
62 |
+
|
63 |
+
force_words = ["Sie"]
|
64 |
+
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
65 |
+
|
66 |
+
outputs = model.generate(
|
67 |
+
input_ids,
|
68 |
+
custom_generate="transformers-community/constrained-beam-search",
|
69 |
+
force_words_ids=force_words_ids,
|
70 |
+
num_beams=5,
|
71 |
+
num_return_sequences=1,
|
72 |
+
no_repeat_ngram_size=1,
|
73 |
+
remove_invalid_values=True,
|
74 |
+
trust_remote_code=True,
|
75 |
+
)
|
76 |
+
|
77 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
78 |
+
```
|
79 |
+
|
80 |
+
Expected to contain the forced word: `Wie alt sind Sie?`
|
81 |
+
|
82 |
+
---
|
83 |
+
|
84 |
+
## Example 2: Disjunctive constraints (choose any of several forms)
|
85 |
+
|
86 |
+
```python
|
87 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
88 |
+
|
89 |
+
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
90 |
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
91 |
+
|
92 |
+
force_word = "scared"
|
93 |
+
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
94 |
+
|
95 |
+
force_words_ids = [
|
96 |
+
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
|
97 |
+
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
|
98 |
+
]
|
99 |
+
|
100 |
+
starting_text = ["The soldiers", "The child"]
|
101 |
+
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids
|
102 |
+
|
103 |
+
outputs = model.generate(
|
104 |
+
input_ids,
|
105 |
+
custom_generate="transformers-community/constrained-beam-search",
|
106 |
+
force_words_ids=force_words_ids,
|
107 |
+
num_beams=10,
|
108 |
+
num_return_sequences=1,
|
109 |
+
no_repeat_ngram_size=1,
|
110 |
+
remove_invalid_values=True,
|
111 |
+
trust_remote_code=True,
|
112 |
+
)
|
113 |
+
|
114 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
115 |
+
print(tokenizer.decode(outputs[1], skip_special_tokens=True))
|
116 |
+
```
|
117 |
+
|
118 |
+
Outputs will include the mandatory word and at least one from the flexible set.
|
config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen3ForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 151643,
|
8 |
+
"eos_token_id": 151645,
|
9 |
+
"head_dim": 128,
|
10 |
+
"hidden_act": "silu",
|
11 |
+
"hidden_size": 1024,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 3072,
|
14 |
+
"max_position_embeddings": 40960,
|
15 |
+
"max_window_layers": 28,
|
16 |
+
"model_type": "qwen3",
|
17 |
+
"num_attention_heads": 16,
|
18 |
+
"num_hidden_layers": 28,
|
19 |
+
"num_key_value_heads": 8,
|
20 |
+
"rms_norm_eps": 1e-06,
|
21 |
+
"rope_scaling": null,
|
22 |
+
"rope_theta": 1000000,
|
23 |
+
"sliding_window": null,
|
24 |
+
"tie_word_embeddings": true,
|
25 |
+
"torch_dtype": "bfloat16",
|
26 |
+
"transformers_version": "4.56.0",
|
27 |
+
"use_cache": true,
|
28 |
+
"use_sliding_window": false,
|
29 |
+
"vocab_size": 151936
|
30 |
+
}
|
custom_generate/beam_constraints.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
|
5 |
+
class Constraint(ABC):
|
6 |
+
r"""Abstract base class for all constraints that can be applied during generation.
|
7 |
+
It must define how the constraint can be satisfied.
|
8 |
+
|
9 |
+
All classes that inherit Constraint must follow the requirement that
|
10 |
+
|
11 |
+
```py
|
12 |
+
completed = False
|
13 |
+
while not completed:
|
14 |
+
_, completed = constraint.update(constraint.advance())
|
15 |
+
```
|
16 |
+
|
17 |
+
will always terminate (halt).
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self):
|
21 |
+
# test for the above condition
|
22 |
+
self.test()
|
23 |
+
|
24 |
+
def test(self):
|
25 |
+
"""
|
26 |
+
Tests whether this constraint has been properly defined.
|
27 |
+
"""
|
28 |
+
counter = 0
|
29 |
+
completed = False
|
30 |
+
while not completed:
|
31 |
+
if counter == 1:
|
32 |
+
self.reset()
|
33 |
+
advance = self.advance()
|
34 |
+
if not self.does_advance(advance):
|
35 |
+
raise Exception(
|
36 |
+
"Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
|
37 |
+
)
|
38 |
+
|
39 |
+
stepped, completed, reset = self.update(advance)
|
40 |
+
counter += 1
|
41 |
+
|
42 |
+
if counter > 10000:
|
43 |
+
raise Exception("update() does not fulfill the constraint.")
|
44 |
+
|
45 |
+
if self.remaining() != 0:
|
46 |
+
raise Exception("Custom Constraint is not defined correctly.")
|
47 |
+
|
48 |
+
@abstractmethod
|
49 |
+
def advance(self):
|
50 |
+
"""
|
51 |
+
When called, returns the token(s) that would take this constraint one step closer to being fulfilled.
|
52 |
+
|
53 |
+
Return:
|
54 |
+
token_ids (Union[int, list[int], None]):
|
55 |
+
- A single token ID (int) that advances the constraint, or
|
56 |
+
- A list of token IDs that could advance the constraint
|
57 |
+
- None if the constraint is completed or cannot be advanced
|
58 |
+
"""
|
59 |
+
raise NotImplementedError(
|
60 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
61 |
+
)
|
62 |
+
|
63 |
+
@abstractmethod
|
64 |
+
def does_advance(self, token_id: int):
|
65 |
+
"""
|
66 |
+
Reads in a token and returns whether it creates progress.
|
67 |
+
"""
|
68 |
+
raise NotImplementedError(
|
69 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
70 |
+
)
|
71 |
+
|
72 |
+
@abstractmethod
|
73 |
+
def update(self, token_id: int):
|
74 |
+
"""
|
75 |
+
Reads in a token and returns booleans that indicate the progress made by it. This function will update the
|
76 |
+
state of this object unlikes `does_advance(self, token_id: int)`.
|
77 |
+
|
78 |
+
This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
|
79 |
+
been generated. This becomes important if token_id != desired token (refer to else statement in
|
80 |
+
PhrasalConstraint)
|
81 |
+
|
82 |
+
Args:
|
83 |
+
token_id(`int`):
|
84 |
+
The id of a newly generated token in the beam search.
|
85 |
+
Return:
|
86 |
+
stepped(`bool`):
|
87 |
+
Whether this constraint has become one step closer to being fulfuilled.
|
88 |
+
completed(`bool`):
|
89 |
+
Whether this constraint has been completely fulfilled by this token being generated.
|
90 |
+
reset (`bool`):
|
91 |
+
Whether this constraint has reset its progress by this token being generated.
|
92 |
+
"""
|
93 |
+
raise NotImplementedError(
|
94 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
95 |
+
)
|
96 |
+
|
97 |
+
@abstractmethod
|
98 |
+
def reset(self):
|
99 |
+
"""
|
100 |
+
Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
|
101 |
+
a constraint is abrupted by an unwanted token.
|
102 |
+
"""
|
103 |
+
raise NotImplementedError(
|
104 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
105 |
+
)
|
106 |
+
|
107 |
+
@abstractmethod
|
108 |
+
def remaining(self):
|
109 |
+
"""
|
110 |
+
Returns the number of remaining steps of `advance()` in order to complete this constraint.
|
111 |
+
"""
|
112 |
+
raise NotImplementedError(
|
113 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
114 |
+
)
|
115 |
+
|
116 |
+
@abstractmethod
|
117 |
+
def copy(self, stateful=False):
|
118 |
+
"""
|
119 |
+
Creates a new instance of this constraint.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state.
|
123 |
+
|
124 |
+
Return:
|
125 |
+
constraint(`Constraint`): The same constraint as the one being called from.
|
126 |
+
"""
|
127 |
+
raise NotImplementedError(
|
128 |
+
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
129 |
+
)
|
130 |
+
|
131 |
+
|
132 |
+
class PhrasalConstraint(Constraint):
|
133 |
+
r"""
|
134 |
+
[`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
token_ids (`list[int]`):
|
138 |
+
The id of the token that must be generated by the output.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(self, token_ids: list[int]):
|
142 |
+
super(Constraint, self).__init__()
|
143 |
+
|
144 |
+
if not isinstance(token_ids, list) or len(token_ids) == 0:
|
145 |
+
raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
|
146 |
+
if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
|
147 |
+
raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
|
148 |
+
|
149 |
+
self.token_ids = token_ids
|
150 |
+
|
151 |
+
self.seqlen = len(self.token_ids)
|
152 |
+
self.fulfilled_idx = -1 # the index of the currently fulfilled step
|
153 |
+
self.completed = False
|
154 |
+
|
155 |
+
def advance(self):
|
156 |
+
if self.completed:
|
157 |
+
return None
|
158 |
+
return self.token_ids[self.fulfilled_idx + 1]
|
159 |
+
|
160 |
+
def does_advance(self, token_id: int):
|
161 |
+
if not isinstance(token_id, int):
|
162 |
+
raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
163 |
+
|
164 |
+
if self.completed:
|
165 |
+
return False
|
166 |
+
|
167 |
+
return token_id == self.token_ids[self.fulfilled_idx + 1]
|
168 |
+
|
169 |
+
def update(self, token_id: int):
|
170 |
+
if not isinstance(token_id, int):
|
171 |
+
raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
172 |
+
|
173 |
+
stepped = False
|
174 |
+
completed = False
|
175 |
+
reset = False
|
176 |
+
|
177 |
+
if self.does_advance(token_id):
|
178 |
+
self.fulfilled_idx += 1
|
179 |
+
stepped = True
|
180 |
+
if self.fulfilled_idx == (self.seqlen - 1):
|
181 |
+
completed = True
|
182 |
+
self.completed = completed
|
183 |
+
else:
|
184 |
+
# failed to make progress.
|
185 |
+
reset = True
|
186 |
+
self.reset()
|
187 |
+
return stepped, completed, reset
|
188 |
+
|
189 |
+
def reset(self):
|
190 |
+
self.completed = False
|
191 |
+
self.fulfilled_idx = 0
|
192 |
+
|
193 |
+
def remaining(self):
|
194 |
+
return self.seqlen - (self.fulfilled_idx + 1)
|
195 |
+
|
196 |
+
def copy(self, stateful=False):
|
197 |
+
new_constraint = PhrasalConstraint(self.token_ids)
|
198 |
+
|
199 |
+
if stateful:
|
200 |
+
new_constraint.seq_len = self.seqlen
|
201 |
+
new_constraint.fulfilled_idx = self.fulfilled_idx
|
202 |
+
new_constraint.completed = self.completed
|
203 |
+
|
204 |
+
return new_constraint
|
205 |
+
|
206 |
+
|
207 |
+
class DisjunctiveTrie:
|
208 |
+
def __init__(self, nested_token_ids: list[list[int]], no_subsets=True):
|
209 |
+
r"""
|
210 |
+
A helper class that builds a trie with the words represented in `nested_token_ids`.
|
211 |
+
"""
|
212 |
+
self.max_height = max([len(one) for one in nested_token_ids])
|
213 |
+
|
214 |
+
root = {}
|
215 |
+
for token_ids in nested_token_ids:
|
216 |
+
level = root
|
217 |
+
for tidx, token_id in enumerate(token_ids):
|
218 |
+
if token_id not in level:
|
219 |
+
level[token_id] = {}
|
220 |
+
|
221 |
+
level = level[token_id]
|
222 |
+
|
223 |
+
if no_subsets and self.has_subsets(root, nested_token_ids):
|
224 |
+
raise ValueError(
|
225 |
+
"Each list in `nested_token_ids` can't be a complete subset of another list, but is"
|
226 |
+
f" {nested_token_ids}."
|
227 |
+
)
|
228 |
+
|
229 |
+
self.trie = root
|
230 |
+
|
231 |
+
def next_tokens(self, current_seq):
|
232 |
+
"""
|
233 |
+
The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
|
234 |
+
"""
|
235 |
+
start = self.trie
|
236 |
+
|
237 |
+
for current_token in current_seq:
|
238 |
+
start = start[current_token]
|
239 |
+
|
240 |
+
next_tokens = list(start.keys())
|
241 |
+
|
242 |
+
return next_tokens
|
243 |
+
|
244 |
+
def reached_leaf(self, current_seq):
|
245 |
+
next_tokens = self.next_tokens(current_seq)
|
246 |
+
|
247 |
+
return len(next_tokens) == 0
|
248 |
+
|
249 |
+
def count_leaves(self, root):
|
250 |
+
next_nodes = list(root.values())
|
251 |
+
if len(next_nodes) == 0:
|
252 |
+
return 1
|
253 |
+
else:
|
254 |
+
return sum([self.count_leaves(nn) for nn in next_nodes])
|
255 |
+
|
256 |
+
def has_subsets(self, trie, nested_token_ids):
|
257 |
+
"""
|
258 |
+
Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
|
259 |
+
"""
|
260 |
+
leaf_count = self.count_leaves(trie)
|
261 |
+
return len(nested_token_ids) != leaf_count
|
262 |
+
|
263 |
+
|
264 |
+
class DisjunctiveConstraint(Constraint):
|
265 |
+
r"""
|
266 |
+
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
nested_token_ids (`list[list[int]]`):
|
270 |
+
A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
|
271 |
+
the list of words.
|
272 |
+
"""
|
273 |
+
|
274 |
+
def __init__(self, nested_token_ids: list[list[int]]):
|
275 |
+
super(Constraint, self).__init__()
|
276 |
+
|
277 |
+
if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
|
278 |
+
raise ValueError(f"`nested_token_ids` has to be a non-empty list, but is {nested_token_ids}.")
|
279 |
+
if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
|
280 |
+
raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
|
281 |
+
if any(
|
282 |
+
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
283 |
+
for token_ids in nested_token_ids
|
284 |
+
):
|
285 |
+
raise ValueError(
|
286 |
+
f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
|
287 |
+
)
|
288 |
+
|
289 |
+
self.trie = DisjunctiveTrie(nested_token_ids)
|
290 |
+
self.token_ids = nested_token_ids
|
291 |
+
|
292 |
+
self.seqlen = self.trie.max_height
|
293 |
+
self.current_seq = []
|
294 |
+
self.completed = False
|
295 |
+
|
296 |
+
def advance(self):
|
297 |
+
token_list = self.trie.next_tokens(self.current_seq)
|
298 |
+
|
299 |
+
if len(token_list) == 0:
|
300 |
+
return None
|
301 |
+
else:
|
302 |
+
return token_list
|
303 |
+
|
304 |
+
def does_advance(self, token_id: int):
|
305 |
+
if not isinstance(token_id, int):
|
306 |
+
raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
307 |
+
|
308 |
+
next_tokens = self.trie.next_tokens(self.current_seq)
|
309 |
+
|
310 |
+
return token_id in next_tokens
|
311 |
+
|
312 |
+
def update(self, token_id: int):
|
313 |
+
if not isinstance(token_id, int):
|
314 |
+
raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
315 |
+
|
316 |
+
stepped = False
|
317 |
+
completed = False
|
318 |
+
reset = False
|
319 |
+
|
320 |
+
if self.does_advance(token_id):
|
321 |
+
self.current_seq.append(token_id)
|
322 |
+
stepped = True
|
323 |
+
else:
|
324 |
+
reset = True
|
325 |
+
self.reset()
|
326 |
+
|
327 |
+
completed = self.trie.reached_leaf(self.current_seq)
|
328 |
+
self.completed = completed
|
329 |
+
|
330 |
+
return stepped, completed, reset
|
331 |
+
|
332 |
+
def reset(self):
|
333 |
+
self.completed = False
|
334 |
+
self.current_seq = []
|
335 |
+
|
336 |
+
def remaining(self):
|
337 |
+
if self.completed:
|
338 |
+
# since this can be completed without reaching max height
|
339 |
+
return 0
|
340 |
+
else:
|
341 |
+
return self.seqlen - len(self.current_seq)
|
342 |
+
|
343 |
+
def copy(self, stateful=False):
|
344 |
+
new_constraint = DisjunctiveConstraint(self.token_ids)
|
345 |
+
|
346 |
+
if stateful:
|
347 |
+
new_constraint.seq_len = self.seqlen
|
348 |
+
new_constraint.current_seq = self.current_seq
|
349 |
+
new_constraint.completed = self.completed
|
350 |
+
|
351 |
+
return new_constraint
|
352 |
+
|
353 |
+
|
354 |
+
class ConstraintListState:
|
355 |
+
r"""
|
356 |
+
A class for beam scorers to track its progress through a list of constraints.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
constraints (`list[Constraint]`):
|
360 |
+
A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
|
361 |
+
"""
|
362 |
+
|
363 |
+
def __init__(self, constraints: list[Constraint]):
|
364 |
+
self.constraints = constraints
|
365 |
+
|
366 |
+
# max # of steps required to fulfill a given constraint
|
367 |
+
self.max_seqlen = max([c.seqlen for c in constraints])
|
368 |
+
self.n_constraints = len(constraints)
|
369 |
+
self.completed = False
|
370 |
+
|
371 |
+
self.init_state()
|
372 |
+
|
373 |
+
def init_state(self):
|
374 |
+
self.complete_constraints = []
|
375 |
+
self.inprogress_constraint = None
|
376 |
+
self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
|
377 |
+
|
378 |
+
def get_bank(self):
|
379 |
+
add = 0
|
380 |
+
if self.inprogress_constraint:
|
381 |
+
# extra points for having a constraint mid-fulfilled
|
382 |
+
add += self.max_seqlen - self.inprogress_constraint.remaining()
|
383 |
+
|
384 |
+
return (len(self.complete_constraints) * self.max_seqlen) + add
|
385 |
+
|
386 |
+
def advance(self):
|
387 |
+
"""The list of tokens to generate such that we can make progress.
|
388 |
+
By "list" we don't mean the list of token that will fully fulfill a constraint.
|
389 |
+
|
390 |
+
Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
|
391 |
+
specific constraint `c_i`, we return:
|
392 |
+
|
393 |
+
`[t_k1 for k in indices of unfulfilled constraints]`
|
394 |
+
|
395 |
+
If we are in the middle of a constraint, then we return:
|
396 |
+
`[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
|
397 |
+
|
398 |
+
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
|
399 |
+
that's the only one we'll return.
|
400 |
+
"""
|
401 |
+
token_list = []
|
402 |
+
if self.inprogress_constraint is None:
|
403 |
+
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
|
404 |
+
advance = constraint.advance()
|
405 |
+
if isinstance(advance, int):
|
406 |
+
token_list.append(advance)
|
407 |
+
elif isinstance(advance, list):
|
408 |
+
token_list.extend(advance)
|
409 |
+
else:
|
410 |
+
advance = self.inprogress_constraint.advance()
|
411 |
+
if isinstance(advance, int):
|
412 |
+
token_list.append(advance)
|
413 |
+
elif isinstance(advance, list):
|
414 |
+
token_list.extend(advance)
|
415 |
+
|
416 |
+
if len(token_list) == 0:
|
417 |
+
return None
|
418 |
+
else:
|
419 |
+
return token_list
|
420 |
+
|
421 |
+
def reset(self, token_ids: Optional[list[int]]):
|
422 |
+
"""
|
423 |
+
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
|
424 |
+
"""
|
425 |
+
self.init_state()
|
426 |
+
|
427 |
+
if token_ids is not None:
|
428 |
+
for token in token_ids:
|
429 |
+
# completes or steps **one** constraint
|
430 |
+
complete, stepped = self.add(token)
|
431 |
+
|
432 |
+
# the entire list of constraints are fulfilled
|
433 |
+
if self.completed:
|
434 |
+
break
|
435 |
+
|
436 |
+
def add(self, token_id: int):
|
437 |
+
if not isinstance(token_id, int):
|
438 |
+
raise TypeError(f"`token_id` should be an `int`, but is `{token_id}`.")
|
439 |
+
|
440 |
+
complete, stepped = False, False
|
441 |
+
|
442 |
+
if self.completed:
|
443 |
+
complete = True
|
444 |
+
stepped = False
|
445 |
+
return complete, stepped
|
446 |
+
|
447 |
+
if self.inprogress_constraint is not None:
|
448 |
+
# In the middle of fulfilling a constraint. If the `token_id` *does* makes an incremental progress to current
|
449 |
+
# job, simply update the state
|
450 |
+
|
451 |
+
stepped, complete, reset = self.inprogress_constraint.update(token_id)
|
452 |
+
if reset:
|
453 |
+
# 1. If the next token breaks the progress, then we must restart.
|
454 |
+
# e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
|
455 |
+
|
456 |
+
# But that doesn't mean we self.init_state(), since we only reset the state for this particular
|
457 |
+
# constraint, not the full list of constraints.
|
458 |
+
|
459 |
+
self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
|
460 |
+
self.inprogress_constraint = None
|
461 |
+
|
462 |
+
if complete:
|
463 |
+
# 2. If the next token completes the constraint, move it to completed list, set
|
464 |
+
# inprogress to None. If there are no pending constraints either, then this full list of constraints
|
465 |
+
# is complete.
|
466 |
+
|
467 |
+
self.complete_constraints.append(self.inprogress_constraint)
|
468 |
+
self.inprogress_constraint = None
|
469 |
+
|
470 |
+
if len(self.pending_constraints) == 0:
|
471 |
+
# we're done!
|
472 |
+
self.completed = True
|
473 |
+
|
474 |
+
else:
|
475 |
+
# Not in the middle of fulfilling a constraint. So does this `token_id` helps us step towards any of our list
|
476 |
+
# of constraints?
|
477 |
+
|
478 |
+
for cidx, pending_constraint in enumerate(self.pending_constraints):
|
479 |
+
if pending_constraint.does_advance(token_id):
|
480 |
+
stepped, complete, reset = pending_constraint.update(token_id)
|
481 |
+
|
482 |
+
if not stepped:
|
483 |
+
raise Exception(
|
484 |
+
"`constraint.update(token_id)` is not yielding incremental progress, "
|
485 |
+
"even though `constraint.does_advance(token_id)` is true."
|
486 |
+
)
|
487 |
+
|
488 |
+
if complete:
|
489 |
+
self.complete_constraints.append(pending_constraint)
|
490 |
+
self.inprogress_constraint = None
|
491 |
+
|
492 |
+
if not complete and stepped:
|
493 |
+
self.inprogress_constraint = pending_constraint
|
494 |
+
|
495 |
+
if complete or stepped:
|
496 |
+
# If we made any progress at all, then it's at least not a "pending constraint".
|
497 |
+
|
498 |
+
self.pending_constraints = (
|
499 |
+
self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]
|
500 |
+
)
|
501 |
+
|
502 |
+
if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
|
503 |
+
# If there's no longer any pending after this and no inprogress either, then we must be
|
504 |
+
# complete.
|
505 |
+
|
506 |
+
self.completed = True
|
507 |
+
|
508 |
+
break # prevent accidentally stepping through multiple constraints with just one token.
|
509 |
+
|
510 |
+
return complete, stepped
|
511 |
+
|
512 |
+
def copy(self, stateful=True):
|
513 |
+
new_state = ConstraintListState(self.constraints) # we actually never though self.constraints objects
|
514 |
+
# throughout this process. So it's at initialization state.
|
515 |
+
|
516 |
+
if stateful:
|
517 |
+
new_state.complete_constraints = [
|
518 |
+
constraint.copy(stateful=True) for constraint in self.complete_constraints
|
519 |
+
]
|
520 |
+
if self.inprogress_constraint is not None:
|
521 |
+
new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
|
522 |
+
new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
|
523 |
+
|
524 |
+
return new_state
|
custom_generate/beam_search.py
ADDED
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The HuggingFace Inc. team
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
from collections import UserDict
|
18 |
+
from typing import Optional, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from transformers.utils import add_start_docstrings
|
24 |
+
from .beam_constraints import Constraint, ConstraintListState
|
25 |
+
|
26 |
+
|
27 |
+
PROCESS_INPUTS_DOCSTRING = r"""
|
28 |
+
Args:
|
29 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
30 |
+
Indices of input sequence tokens in the vocabulary.
|
31 |
+
|
32 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
33 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
34 |
+
|
35 |
+
[What are input IDs?](../glossary#input-ids)
|
36 |
+
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
|
37 |
+
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
|
38 |
+
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
39 |
+
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
|
40 |
+
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
41 |
+
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
42 |
+
pad_token_id (`int`, *optional*):
|
43 |
+
The id of the *padding* token.
|
44 |
+
eos_token_id (`Union[int, list[int]]`, *optional*):
|
45 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
46 |
+
beam_indices (`torch.LongTensor`, *optional*):
|
47 |
+
Beam indices indicating to which beam hypothesis each token correspond.
|
48 |
+
group_index (`int`, *optional*):
|
49 |
+
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
|
50 |
+
|
51 |
+
Return:
|
52 |
+
`UserDict`: A dictionary composed of the fields as defined above:
|
53 |
+
|
54 |
+
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
|
55 |
+
non-finished beams.
|
56 |
+
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
|
57 |
+
to the non-finished beam_hypotheses.
|
58 |
+
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
|
59 |
+
indicating to which beam the next tokens shall be added.
|
60 |
+
|
61 |
+
"""
|
62 |
+
|
63 |
+
FINALIZE_INPUTS_DOCSTRING = r"""
|
64 |
+
Args:
|
65 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
66 |
+
Indices of input sequence tokens in the vocabulary.
|
67 |
+
|
68 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
69 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
70 |
+
|
71 |
+
[What are input IDs?](../glossary#input-ids)
|
72 |
+
final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
73 |
+
The final scores of all non-finished beams.
|
74 |
+
final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
75 |
+
The last tokens to be added to the non-finished beam_hypotheses.
|
76 |
+
final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
77 |
+
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
|
78 |
+
pad_token_id (`int`, *optional*):
|
79 |
+
The id of the *padding* token.
|
80 |
+
eos_token_id (`Union[int, list[int]]`, *optional*):
|
81 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
82 |
+
|
83 |
+
Return:
|
84 |
+
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
|
85 |
+
The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
|
86 |
+
due to the `eos_token_id`.
|
87 |
+
|
88 |
+
"""
|
89 |
+
|
90 |
+
|
91 |
+
class BeamScorer(ABC):
|
92 |
+
"""
|
93 |
+
Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
|
94 |
+
[`~PreTrainedModel.beam_sample`].
|
95 |
+
"""
|
96 |
+
|
97 |
+
@abstractmethod
|
98 |
+
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
|
99 |
+
def process(
|
100 |
+
self,
|
101 |
+
input_ids: torch.LongTensor,
|
102 |
+
next_scores: torch.FloatTensor,
|
103 |
+
next_tokens: torch.LongTensor,
|
104 |
+
next_indices: torch.LongTensor,
|
105 |
+
**kwargs,
|
106 |
+
) -> tuple[torch.Tensor]:
|
107 |
+
raise NotImplementedError("This is an abstract method.")
|
108 |
+
|
109 |
+
@abstractmethod
|
110 |
+
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
|
111 |
+
def finalize(
|
112 |
+
self,
|
113 |
+
input_ids: torch.LongTensor,
|
114 |
+
next_scores: torch.FloatTensor,
|
115 |
+
next_tokens: torch.LongTensor,
|
116 |
+
next_indices: torch.LongTensor,
|
117 |
+
max_length: int,
|
118 |
+
**kwargs,
|
119 |
+
) -> torch.LongTensor:
|
120 |
+
raise NotImplementedError("This is an abstract method.")
|
121 |
+
|
122 |
+
class ConstrainedBeamSearchScorer(BeamScorer):
|
123 |
+
r"""
|
124 |
+
[`BeamScorer`] implementing constrained beam search decoding.
|
125 |
+
|
126 |
+
|
127 |
+
Args:
|
128 |
+
batch_size (`int`):
|
129 |
+
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
|
130 |
+
num_beams (`int`):
|
131 |
+
Number of beams for beam search.
|
132 |
+
constraints (`list[Constraint]`):
|
133 |
+
A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
|
134 |
+
output. For more information, the documentation of [`Constraint`] should be read.
|
135 |
+
device (`torch.device`):
|
136 |
+
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
|
137 |
+
allocated.
|
138 |
+
length_penalty (`float`, *optional*, defaults to 1.0):
|
139 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
140 |
+
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
141 |
+
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
142 |
+
`length_penalty` < 0.0 encourages shorter sequences.
|
143 |
+
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
144 |
+
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
145 |
+
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
146 |
+
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
147 |
+
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
148 |
+
beam search algorithm).
|
149 |
+
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
|
150 |
+
The number of beam hypotheses that shall be returned upon calling
|
151 |
+
[`~transformers.BeamSearchScorer.finalize`].
|
152 |
+
num_beam_groups (`int`, *optional*, defaults to 1):
|
153 |
+
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
154 |
+
See [this paper](https://huggingface.co/papers/1610.02424) for more details.
|
155 |
+
max_length (`int`, *optional*):
|
156 |
+
The maximum length of the sequence to be generated.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
batch_size: int,
|
162 |
+
num_beams: int,
|
163 |
+
constraints: list[Constraint],
|
164 |
+
device: torch.device,
|
165 |
+
length_penalty: Optional[float] = 1.0,
|
166 |
+
do_early_stopping: Optional[Union[bool, str]] = False,
|
167 |
+
num_beam_hyps_to_keep: Optional[int] = 1,
|
168 |
+
num_beam_groups: Optional[int] = 1,
|
169 |
+
max_length: Optional[int] = None,
|
170 |
+
):
|
171 |
+
self.num_beams = num_beams
|
172 |
+
self.device = device
|
173 |
+
self.length_penalty = length_penalty
|
174 |
+
self.do_early_stopping = do_early_stopping
|
175 |
+
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
176 |
+
self.num_beam_groups = num_beam_groups
|
177 |
+
self.group_size = self.num_beams // self.num_beam_groups
|
178 |
+
self.constraints = constraints
|
179 |
+
|
180 |
+
self._is_init = False
|
181 |
+
self._beam_hyps = [
|
182 |
+
BeamHypotheses(
|
183 |
+
num_beams=self.num_beams,
|
184 |
+
length_penalty=self.length_penalty,
|
185 |
+
early_stopping=self.do_early_stopping,
|
186 |
+
max_length=max_length,
|
187 |
+
)
|
188 |
+
for _ in range(batch_size)
|
189 |
+
]
|
190 |
+
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
|
191 |
+
|
192 |
+
if not isinstance(num_beams, int) or num_beams <= 1:
|
193 |
+
raise ValueError(
|
194 |
+
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
195 |
+
" one should make use of `greedy_search` instead."
|
196 |
+
)
|
197 |
+
|
198 |
+
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
199 |
+
raise ValueError(
|
200 |
+
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
201 |
+
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
202 |
+
)
|
203 |
+
|
204 |
+
@property
|
205 |
+
def is_done(self) -> bool:
|
206 |
+
return self._done.all()
|
207 |
+
|
208 |
+
def make_constraint_states(self, n):
|
209 |
+
return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
|
210 |
+
|
211 |
+
def check_completes_constraints(self, sequence):
|
212 |
+
new_state = self.make_constraint_states(1)[0]
|
213 |
+
new_state.reset(sequence)
|
214 |
+
return new_state.completed
|
215 |
+
|
216 |
+
def process(
|
217 |
+
self,
|
218 |
+
input_ids: torch.LongTensor,
|
219 |
+
next_scores: torch.FloatTensor,
|
220 |
+
next_tokens: torch.LongTensor,
|
221 |
+
next_indices: torch.LongTensor,
|
222 |
+
scores_for_all_vocab: torch.FloatTensor,
|
223 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
224 |
+
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
225 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
226 |
+
decoder_prompt_len: Optional[int] = 0,
|
227 |
+
) -> tuple[torch.Tensor]:
|
228 |
+
r"""
|
229 |
+
Args:
|
230 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
231 |
+
Indices of input sequence tokens in the vocabulary.
|
232 |
+
|
233 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
234 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
235 |
+
|
236 |
+
[What are input IDs?](../glossary#input-ids)
|
237 |
+
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
|
238 |
+
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
|
239 |
+
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
240 |
+
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
|
241 |
+
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
242 |
+
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
243 |
+
scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
244 |
+
The scores of all tokens in the vocabulary for each of the beam hypotheses.
|
245 |
+
pad_token_id (`int`, *optional*):
|
246 |
+
The id of the *padding* token.
|
247 |
+
eos_token_id (`Union[int, list[int]]`, *optional*):
|
248 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
249 |
+
beam_indices (`torch.LongTensor`, *optional*):
|
250 |
+
Beam indices indicating to which beam hypothesis each token correspond.
|
251 |
+
decoder_prompt_len (`int`, *optional*):
|
252 |
+
The length of prompt that is included in the input to decoder.
|
253 |
+
Return:
|
254 |
+
`UserDict`: A dictionary composed of the fields as defined above:
|
255 |
+
|
256 |
+
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
|
257 |
+
all
|
258 |
+
non-finished beams.
|
259 |
+
|
260 |
+
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
|
261 |
+
added
|
262 |
+
to the non-finished beam_hypotheses.
|
263 |
+
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
|
264 |
+
indicating to which beam the next tokens shall be added.
|
265 |
+
"""
|
266 |
+
|
267 |
+
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
268 |
+
cur_len = input_ids.shape[-1] + 1
|
269 |
+
batch_size = len(self._beam_hyps)
|
270 |
+
if batch_size != (input_ids.shape[0] // self.group_size):
|
271 |
+
if self.num_beam_groups > 1:
|
272 |
+
raise ValueError(
|
273 |
+
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
274 |
+
f"size of {self.group_size} is expected by the beam scorer."
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
raise ValueError(
|
278 |
+
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
|
279 |
+
f"{self.group_size} is expected by the beam scorer."
|
280 |
+
)
|
281 |
+
|
282 |
+
device = input_ids.device
|
283 |
+
|
284 |
+
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
285 |
+
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
286 |
+
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
287 |
+
|
288 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
289 |
+
if isinstance(eos_token_id, int):
|
290 |
+
eos_token_id = [eos_token_id]
|
291 |
+
eos_token_id = torch.tensor(eos_token_id)
|
292 |
+
|
293 |
+
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
294 |
+
if self._done[batch_idx]:
|
295 |
+
if self.num_beams < len(beam_hyp):
|
296 |
+
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
297 |
+
if eos_token_id is None or pad_token_id is None:
|
298 |
+
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
299 |
+
# pad the batch
|
300 |
+
next_beam_scores[batch_idx, :] = 0
|
301 |
+
next_beam_tokens[batch_idx, :] = pad_token_id
|
302 |
+
next_beam_indices[batch_idx, :] = 0
|
303 |
+
continue
|
304 |
+
|
305 |
+
# next tokens for this sentence.
|
306 |
+
beam_idx = 0
|
307 |
+
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
308 |
+
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
309 |
+
):
|
310 |
+
batch_beam_idx = batch_idx * self.group_size + next_index
|
311 |
+
# add to generated hypotheses if end of sentence
|
312 |
+
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
313 |
+
# if beam_token does not belong to top num_beams tokens, it should not be added
|
314 |
+
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
315 |
+
if is_beam_token_worse_than_top_num_beams:
|
316 |
+
continue
|
317 |
+
|
318 |
+
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].tolist())
|
319 |
+
if completes_constraint:
|
320 |
+
if beam_indices is not None:
|
321 |
+
beam_index = beam_indices[batch_beam_idx]
|
322 |
+
beam_index = beam_index + (batch_beam_idx,)
|
323 |
+
else:
|
324 |
+
beam_index = None
|
325 |
+
|
326 |
+
beam_hyp.add(
|
327 |
+
input_ids[batch_beam_idx].clone(),
|
328 |
+
next_score.item(),
|
329 |
+
beam_indices=beam_index,
|
330 |
+
generated_len=cur_len - decoder_prompt_len,
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
# add next predicted token since it is not eos_token
|
334 |
+
next_beam_scores[batch_idx, beam_idx] = next_score
|
335 |
+
next_beam_tokens[batch_idx, beam_idx] = next_token
|
336 |
+
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
337 |
+
beam_idx += 1
|
338 |
+
|
339 |
+
# once the beam for next step is full, don't add more tokens to it.
|
340 |
+
if beam_idx == self.group_size:
|
341 |
+
break
|
342 |
+
|
343 |
+
new_scores, new_tokens, new_indices = self.step_sentence_constraint(
|
344 |
+
batch_idx,
|
345 |
+
input_ids,
|
346 |
+
scores_for_all_vocab,
|
347 |
+
next_beam_scores[batch_idx],
|
348 |
+
next_beam_tokens[batch_idx],
|
349 |
+
next_beam_indices[batch_idx],
|
350 |
+
)
|
351 |
+
|
352 |
+
next_beam_scores[batch_idx] = new_scores
|
353 |
+
next_beam_tokens[batch_idx] = new_tokens
|
354 |
+
next_beam_indices[batch_idx] = new_indices
|
355 |
+
|
356 |
+
if beam_idx < self.group_size:
|
357 |
+
raise ValueError(
|
358 |
+
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
359 |
+
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
360 |
+
)
|
361 |
+
|
362 |
+
# Check if we are done so that we can save a pad step if all(done)
|
363 |
+
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
|
364 |
+
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
365 |
+
)
|
366 |
+
|
367 |
+
return UserDict(
|
368 |
+
{
|
369 |
+
"next_beam_scores": next_beam_scores.view(-1),
|
370 |
+
"next_beam_tokens": next_beam_tokens.view(-1),
|
371 |
+
"next_beam_indices": next_beam_indices.view(-1),
|
372 |
+
}
|
373 |
+
)
|
374 |
+
|
375 |
+
def step_sentence_constraint(
|
376 |
+
self,
|
377 |
+
batch_idx: int,
|
378 |
+
input_ids: torch.LongTensor,
|
379 |
+
vocab_scores: torch.FloatTensor,
|
380 |
+
sent_beam_scores: torch.FloatTensor,
|
381 |
+
sent_beam_tokens: torch.LongTensor,
|
382 |
+
sent_beam_indices: torch.LongTensor,
|
383 |
+
push_progress: bool = False,
|
384 |
+
):
|
385 |
+
# sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
|
386 |
+
# (candidate next tokens)
|
387 |
+
|
388 |
+
# 1. Adding "advance_tokens"
|
389 |
+
# using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
|
390 |
+
# advance us in fulfilling the constraints.
|
391 |
+
|
392 |
+
# 2. Selecting best candidates such that we end up with highest probable candidates
|
393 |
+
# that fulfill our constraints.
|
394 |
+
|
395 |
+
orig_len = sent_beam_indices.size(0)
|
396 |
+
device = sent_beam_indices.device
|
397 |
+
|
398 |
+
# initialize states
|
399 |
+
topk_contraint_states = self.make_constraint_states(orig_len)
|
400 |
+
advance_constraint_states = self.make_constraint_states(orig_len)
|
401 |
+
|
402 |
+
sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
|
403 |
+
this_batch_input_ids = input_ids[sidx:eidx]
|
404 |
+
this_batch_token_scores = vocab_scores[sidx:eidx]
|
405 |
+
full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
|
406 |
+
|
407 |
+
# need to make new hypothesis that advance the constraints
|
408 |
+
track_new = {
|
409 |
+
"new_seqs": full_hypotheses.tolist(),
|
410 |
+
"new_states": [],
|
411 |
+
"new_indices": [],
|
412 |
+
"new_tokens": [],
|
413 |
+
"new_scores": [],
|
414 |
+
}
|
415 |
+
for seq_idx, pre_seq in enumerate(this_batch_input_ids):
|
416 |
+
# pre_seq = ith sequence generated before this step.
|
417 |
+
|
418 |
+
# input_ids -> (topk) generic beam search best model next tokens
|
419 |
+
# -> (advance) constraints forcing the next token
|
420 |
+
# either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
|
421 |
+
# hypotheses.
|
422 |
+
|
423 |
+
topk_state = topk_contraint_states[seq_idx]
|
424 |
+
topk_state.reset(full_hypotheses[seq_idx].tolist())
|
425 |
+
|
426 |
+
advance_state = advance_constraint_states[seq_idx]
|
427 |
+
advance_state.reset(pre_seq.tolist())
|
428 |
+
|
429 |
+
if not advance_state.completed:
|
430 |
+
advance_tokens = torch.tensor(advance_state.advance(), dtype=torch.long, device=device)
|
431 |
+
for advance_token in advance_tokens:
|
432 |
+
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
|
433 |
+
new_state = advance_state.copy(stateful=True)
|
434 |
+
new_state.add(advance_token.tolist())
|
435 |
+
|
436 |
+
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).tolist()
|
437 |
+
if advance_seq not in track_new["new_seqs"]:
|
438 |
+
# prevent duplicates, which are basically bound to happen in this process.
|
439 |
+
track_new["new_seqs"].append(advance_seq)
|
440 |
+
track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
|
441 |
+
track_new["new_tokens"].append(advance_token)
|
442 |
+
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
|
443 |
+
track_new["new_states"].append(new_state)
|
444 |
+
elif push_progress:
|
445 |
+
# Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
|
446 |
+
# actually fulfill our constraints. For example, let constraints == ["loves pies"] and
|
447 |
+
|
448 |
+
# pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
|
449 |
+
|
450 |
+
# Without this step, if `sent_beam_indices` is something like [1,1], then
|
451 |
+
# 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
|
452 |
+
# 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
|
453 |
+
# the else part of `if constraints_completed[seq_idx]`)
|
454 |
+
# 3. it ends up simply getting removed from consideration.
|
455 |
+
|
456 |
+
# #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
|
457 |
+
# especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
|
458 |
+
# search times, since completed sequences keep getting removed after all this effort for constrained
|
459 |
+
# generation.
|
460 |
+
|
461 |
+
# Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
|
462 |
+
# appending the next likely token in the vocabulary and adding it to the list of hypotheses.
|
463 |
+
|
464 |
+
new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
|
465 |
+
advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
|
466 |
+
|
467 |
+
advance_state = advance_constraint_states[seq_idx]
|
468 |
+
|
469 |
+
advance_seq = advance_seq.tolist()
|
470 |
+
|
471 |
+
advance_state.reset(advance_seq)
|
472 |
+
if advance_seq not in track_new["new_seqs"]:
|
473 |
+
# but still don't want to have duplicates
|
474 |
+
track_new["new_seqs"].append(advance_seq)
|
475 |
+
track_new["new_indices"].append(seq_idx)
|
476 |
+
track_new["new_tokens"].append(new_token)
|
477 |
+
track_new["new_scores"].append(new_score)
|
478 |
+
track_new["new_states"].append(advance_state)
|
479 |
+
|
480 |
+
if len(track_new["new_indices"]) > 0:
|
481 |
+
new_indices = torch.tensor(track_new["new_indices"], device=device)
|
482 |
+
new_tokens = torch.stack(track_new["new_tokens"]).to(device)
|
483 |
+
new_scores = torch.stack(track_new["new_scores"]).to(device)
|
484 |
+
|
485 |
+
all_states = topk_contraint_states + track_new["new_states"]
|
486 |
+
all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
|
487 |
+
all_scores = torch.cat((sent_beam_scores, new_scores), -1)
|
488 |
+
all_banks = torch.tensor([one.get_bank() for one in all_states], device=device)
|
489 |
+
|
490 |
+
zipped = all_banks * 100 + all_scores
|
491 |
+
indices = zipped.sort(descending=True).indices
|
492 |
+
sorted_banks = all_banks[indices]
|
493 |
+
|
494 |
+
# Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
|
495 |
+
|
496 |
+
counter = -1
|
497 |
+
cur_bank = sorted_banks[0]
|
498 |
+
increments = []
|
499 |
+
for bank in sorted_banks:
|
500 |
+
if bank == cur_bank:
|
501 |
+
counter += 1
|
502 |
+
else:
|
503 |
+
counter = 0
|
504 |
+
cur_bank = bank
|
505 |
+
increments.append(counter)
|
506 |
+
rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
|
507 |
+
|
508 |
+
indices = indices[rearrangers][:orig_len]
|
509 |
+
|
510 |
+
sent_beam_scores = all_scores[indices]
|
511 |
+
sent_beam_tokens = all_tokens[indices]
|
512 |
+
sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
|
513 |
+
|
514 |
+
return sent_beam_scores, sent_beam_tokens, sent_beam_indices
|
515 |
+
|
516 |
+
def finalize(
|
517 |
+
self,
|
518 |
+
input_ids: torch.LongTensor,
|
519 |
+
final_beam_scores: torch.FloatTensor,
|
520 |
+
final_beam_tokens: torch.LongTensor,
|
521 |
+
final_beam_indices: torch.LongTensor,
|
522 |
+
max_length: int,
|
523 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
524 |
+
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
525 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
526 |
+
decoder_prompt_len: Optional[int] = 0,
|
527 |
+
) -> tuple[torch.LongTensor]:
|
528 |
+
batch_size = len(self._beam_hyps)
|
529 |
+
|
530 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
531 |
+
if isinstance(eos_token_id, int):
|
532 |
+
eos_token_id = [eos_token_id]
|
533 |
+
eos_token_id = torch.tensor(eos_token_id)
|
534 |
+
|
535 |
+
# finalize all open beam hypotheses and add to generated hypotheses
|
536 |
+
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
537 |
+
if self._done[batch_idx]:
|
538 |
+
continue
|
539 |
+
|
540 |
+
# all open beam hypotheses are added to the beam hypothesis
|
541 |
+
# beam hypothesis class automatically keeps the best beams
|
542 |
+
|
543 |
+
ids_collect = []
|
544 |
+
for beam_id in range(self.num_beams):
|
545 |
+
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
546 |
+
final_score = final_beam_scores[batch_beam_idx].item()
|
547 |
+
final_tokens = input_ids[batch_beam_idx]
|
548 |
+
|
549 |
+
completes_constraint = self.check_completes_constraints(final_tokens.tolist())
|
550 |
+
if completes_constraint:
|
551 |
+
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
552 |
+
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
553 |
+
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
554 |
+
ids_collect.append(beam_id)
|
555 |
+
|
556 |
+
# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
|
557 |
+
# generation. In these cases we simply return the highest scoring outputs.
|
558 |
+
if len(ids_collect) < self.num_beam_hyps_to_keep:
|
559 |
+
for beam_id in range(self.num_beams):
|
560 |
+
if beam_id not in ids_collect:
|
561 |
+
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
562 |
+
final_score = final_beam_scores[batch_beam_idx].item()
|
563 |
+
final_tokens = input_ids[batch_beam_idx]
|
564 |
+
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
565 |
+
beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
|
566 |
+
if len(ids_collect) >= self.num_beam_hyps_to_keep:
|
567 |
+
break
|
568 |
+
|
569 |
+
# select the best hypotheses
|
570 |
+
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
571 |
+
best = []
|
572 |
+
best_indices = []
|
573 |
+
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
574 |
+
|
575 |
+
# retrieve best hypotheses
|
576 |
+
for i, beam_hyp in enumerate(self._beam_hyps):
|
577 |
+
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
|
578 |
+
for j in range(self.num_beam_hyps_to_keep):
|
579 |
+
best_hyp_tuple = sorted_hyps.pop()
|
580 |
+
best_score = best_hyp_tuple[0]
|
581 |
+
best_hyp = best_hyp_tuple[1]
|
582 |
+
best_index = best_hyp_tuple[2]
|
583 |
+
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
584 |
+
|
585 |
+
# append to lists
|
586 |
+
best.append(best_hyp)
|
587 |
+
|
588 |
+
# append indices to list
|
589 |
+
best_indices.append(best_index)
|
590 |
+
|
591 |
+
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
592 |
+
|
593 |
+
# prepare for adding eos
|
594 |
+
sent_lengths_max = sent_lengths.max().item() + 1
|
595 |
+
|
596 |
+
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
597 |
+
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
598 |
+
|
599 |
+
if len(best_indices) > 0 and best_indices[0] is not None:
|
600 |
+
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
601 |
+
else:
|
602 |
+
indices = None
|
603 |
+
|
604 |
+
# shorter batches are padded if needed
|
605 |
+
if sent_lengths.min().item() != sent_lengths.max().item():
|
606 |
+
if pad_token_id is None:
|
607 |
+
raise ValueError("`pad_token_id` has to be defined")
|
608 |
+
decoded.fill_(pad_token_id)
|
609 |
+
|
610 |
+
if indices is not None:
|
611 |
+
indices.fill_(-1)
|
612 |
+
|
613 |
+
# fill with hypotheses and eos_token_id if the latter fits in
|
614 |
+
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
615 |
+
decoded[i, : sent_lengths[i]] = hypo
|
616 |
+
|
617 |
+
if indices is not None:
|
618 |
+
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
619 |
+
|
620 |
+
if sent_lengths[i] < sent_max_len:
|
621 |
+
# inserting only the first eos_token_id
|
622 |
+
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
623 |
+
|
624 |
+
return UserDict(
|
625 |
+
{
|
626 |
+
"sequences": decoded,
|
627 |
+
"sequence_scores": best_scores,
|
628 |
+
"beam_indices": indices,
|
629 |
+
}
|
630 |
+
)
|
631 |
+
|
632 |
+
|
633 |
+
class BeamHypotheses:
|
634 |
+
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
|
635 |
+
"""
|
636 |
+
Initialize n-best list of hypotheses.
|
637 |
+
"""
|
638 |
+
self.length_penalty = length_penalty
|
639 |
+
self.early_stopping = early_stopping
|
640 |
+
self.max_length = max_length
|
641 |
+
self.num_beams = num_beams
|
642 |
+
self.beams = []
|
643 |
+
self.worst_score = 1e9
|
644 |
+
|
645 |
+
if not isinstance(self.early_stopping, bool) and self.max_length is None:
|
646 |
+
raise ValueError(
|
647 |
+
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
|
648 |
+
" BeamScorer class instance at initialization time."
|
649 |
+
)
|
650 |
+
|
651 |
+
def __len__(self):
|
652 |
+
"""
|
653 |
+
Number of hypotheses in the list.
|
654 |
+
"""
|
655 |
+
return len(self.beams)
|
656 |
+
|
657 |
+
def add(
|
658 |
+
self,
|
659 |
+
hyp: torch.LongTensor,
|
660 |
+
sum_logprobs: float,
|
661 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
662 |
+
generated_len: Optional[int] = None,
|
663 |
+
):
|
664 |
+
"""
|
665 |
+
Add a new hypothesis to the list.
|
666 |
+
"""
|
667 |
+
if generated_len is not None:
|
668 |
+
score = sum_logprobs / (generated_len**self.length_penalty)
|
669 |
+
# This 'else' case exists for retrocompatibility
|
670 |
+
else:
|
671 |
+
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
672 |
+
|
673 |
+
if len(self) < self.num_beams or score > self.worst_score:
|
674 |
+
self.beams.append((score, hyp, beam_indices))
|
675 |
+
if len(self) > self.num_beams:
|
676 |
+
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
|
677 |
+
del self.beams[sorted_next_scores[0][1]]
|
678 |
+
self.worst_score = sorted_next_scores[1][0]
|
679 |
+
else:
|
680 |
+
self.worst_score = min(score, self.worst_score)
|
681 |
+
|
682 |
+
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
|
683 |
+
"""
|
684 |
+
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
685 |
+
one in the heap, then we are done with this sentence.
|
686 |
+
"""
|
687 |
+
|
688 |
+
if len(self) < self.num_beams:
|
689 |
+
return False
|
690 |
+
|
691 |
+
# `True`: stop as soon as at least `num_beams` hypotheses are finished
|
692 |
+
if self.early_stopping is True:
|
693 |
+
return True
|
694 |
+
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
|
695 |
+
# when `length_penalty` is positive. See the discussion below for more details.
|
696 |
+
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
697 |
+
elif self.early_stopping is False:
|
698 |
+
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
699 |
+
ret = self.worst_score >= highest_attainable_score
|
700 |
+
return ret
|
701 |
+
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
|
702 |
+
else:
|
703 |
+
# `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
|
704 |
+
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
|
705 |
+
# its max this way
|
706 |
+
if self.length_penalty > 0.0:
|
707 |
+
if self.max_length <= decoder_prompt_len:
|
708 |
+
raise ValueError("max_length is not larger than decoder prompt length")
|
709 |
+
highest_attainable_score = (
|
710 |
+
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
|
711 |
+
)
|
712 |
+
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
|
713 |
+
else:
|
714 |
+
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
715 |
+
ret = self.worst_score >= highest_attainable_score
|
716 |
+
return ret
|
custom_generate/generate.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
import torch
|
3 |
+
from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
4 |
+
from transformers.generation.utils import (
|
5 |
+
GenerateBeamOutput,
|
6 |
+
GenerationMixin,
|
7 |
+
GenerateBeamDecoderOnlyOutput,
|
8 |
+
GenerateBeamEncoderDecoderOutput,
|
9 |
+
)
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import numpy as np
|
13 |
+
import logging
|
14 |
+
|
15 |
+
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
16 |
+
from .beam_search import ConstrainedBeamSearchScorer
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
def _constrained_beam_search(
|
21 |
+
model,
|
22 |
+
input_ids: torch.LongTensor,
|
23 |
+
logits_processor: LogitsProcessorList,
|
24 |
+
stopping_criteria: StoppingCriteriaList,
|
25 |
+
generation_config: GenerationConfig,
|
26 |
+
synced_gpus: bool,
|
27 |
+
**model_kwargs,
|
28 |
+
) -> Union[GenerateBeamOutput, torch.LongTensor]:
|
29 |
+
r"""
|
30 |
+
Generates sequences of token ids for models with a language modeling head using **constrained beam search
|
31 |
+
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
35 |
+
The sequence used as a prompt for the generation.
|
36 |
+
logits_processor (`LogitsProcessorList`):
|
37 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
38 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
39 |
+
stopping_criteria (`StoppingCriteriaList`):
|
40 |
+
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
41 |
+
used to tell if the generation loop should stop.
|
42 |
+
generation_config ([`~generation.GenerationConfig`]):
|
43 |
+
The generation configuration to be used as parametrization of the decoding method.
|
44 |
+
synced_gpus (`bool`):
|
45 |
+
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
46 |
+
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
47 |
+
model_kwargs:
|
48 |
+
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
49 |
+
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
50 |
+
|
51 |
+
Return:
|
52 |
+
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
|
53 |
+
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
54 |
+
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
55 |
+
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
56 |
+
`model.config.is_encoder_decoder=True`.
|
57 |
+
"""
|
58 |
+
if generation_config.constraints is not None or generation_config.force_words_ids is not None:
|
59 |
+
constrained_wrong_parameter_msg = (
|
60 |
+
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
|
61 |
+
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
|
62 |
+
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
|
63 |
+
)
|
64 |
+
if generation_config.do_sample is True:
|
65 |
+
raise ValueError(
|
66 |
+
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=generation_config.do_sample)
|
67 |
+
)
|
68 |
+
|
69 |
+
final_constraints = []
|
70 |
+
if generation_config.constraints is not None:
|
71 |
+
final_constraints = generation_config.constraints
|
72 |
+
|
73 |
+
if generation_config.force_words_ids is not None:
|
74 |
+
|
75 |
+
def typeerror():
|
76 |
+
raise ValueError(
|
77 |
+
"`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
|
78 |
+
f"of positive integers, but is {generation_config.force_words_ids}."
|
79 |
+
)
|
80 |
+
|
81 |
+
if (
|
82 |
+
not isinstance(generation_config.force_words_ids, list)
|
83 |
+
or len(generation_config.force_words_ids) == 0
|
84 |
+
):
|
85 |
+
typeerror()
|
86 |
+
|
87 |
+
for word_ids in generation_config.force_words_ids:
|
88 |
+
if isinstance(word_ids[0], list):
|
89 |
+
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
90 |
+
typeerror()
|
91 |
+
if any(not isinstance(token_ids, list) for token_ids in word_ids):
|
92 |
+
typeerror()
|
93 |
+
if any(
|
94 |
+
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
95 |
+
for token_ids in word_ids
|
96 |
+
):
|
97 |
+
typeerror()
|
98 |
+
|
99 |
+
constraint = DisjunctiveConstraint(word_ids)
|
100 |
+
else:
|
101 |
+
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
102 |
+
typeerror()
|
103 |
+
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
|
104 |
+
typeerror()
|
105 |
+
|
106 |
+
constraint = PhrasalConstraint(word_ids)
|
107 |
+
final_constraints.append(constraint)
|
108 |
+
# define beam scorer
|
109 |
+
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
110 |
+
constraints=final_constraints,
|
111 |
+
batch_size=batch_size,
|
112 |
+
num_beams=generation_config.num_beams,
|
113 |
+
device=inputs_tensor.device,
|
114 |
+
length_penalty=generation_config.length_penalty,
|
115 |
+
do_early_stopping=generation_config.early_stopping,
|
116 |
+
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
117 |
+
max_length=generation_config.max_length,
|
118 |
+
)
|
119 |
+
# init values
|
120 |
+
pad_token_id = generation_config._pad_token_tensor
|
121 |
+
eos_token_id = generation_config._eos_token_tensor
|
122 |
+
output_attentions = generation_config.output_attentions
|
123 |
+
output_hidden_states = generation_config.output_hidden_states
|
124 |
+
output_scores = generation_config.output_scores
|
125 |
+
output_logits = generation_config.output_logits
|
126 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
127 |
+
|
128 |
+
batch_size = len(constrained_beam_scorer._beam_hyps)
|
129 |
+
num_beams = constrained_beam_scorer.num_beams
|
130 |
+
|
131 |
+
batch_beam_size, cur_len = input_ids.shape[:2]
|
132 |
+
model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
133 |
+
|
134 |
+
if num_beams * batch_size != batch_beam_size:
|
135 |
+
raise ValueError(
|
136 |
+
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
137 |
+
)
|
138 |
+
|
139 |
+
# init attention / hidden states / scores tuples
|
140 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
141 |
+
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
142 |
+
beam_indices = (
|
143 |
+
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
144 |
+
)
|
145 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
146 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
147 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
148 |
+
|
149 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
150 |
+
if return_dict_in_generate and model.config.is_encoder_decoder:
|
151 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
152 |
+
encoder_hidden_states = (
|
153 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
154 |
+
)
|
155 |
+
|
156 |
+
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
|
157 |
+
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
|
158 |
+
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
159 |
+
beam_scores[:, 1:] = -1e9
|
160 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
161 |
+
|
162 |
+
this_peer_finished = False
|
163 |
+
|
164 |
+
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
|
165 |
+
while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
166 |
+
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
167 |
+
|
168 |
+
# prepare variable output controls (note: some models won't accept all output controls)
|
169 |
+
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
170 |
+
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
171 |
+
|
172 |
+
outputs = model(**model_inputs, return_dict=True)
|
173 |
+
|
174 |
+
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
175 |
+
model_kwargs = model._update_model_kwargs_for_generation(
|
176 |
+
outputs,
|
177 |
+
model_kwargs,
|
178 |
+
is_encoder_decoder=model.config.is_encoder_decoder,
|
179 |
+
)
|
180 |
+
if synced_gpus and this_peer_finished:
|
181 |
+
cur_len = cur_len + 1
|
182 |
+
continue
|
183 |
+
|
184 |
+
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
185 |
+
# (the clone itself is always small)
|
186 |
+
# .float() is needed to retain precision for later logits manipulations
|
187 |
+
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
188 |
+
next_token_scores = nn.functional.log_softmax(
|
189 |
+
next_token_logits, dim=-1
|
190 |
+
) # (batch_size * num_beams, vocab_size)
|
191 |
+
|
192 |
+
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
193 |
+
|
194 |
+
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
195 |
+
next_token_scores_processed
|
196 |
+
)
|
197 |
+
|
198 |
+
scores_for_all_vocab = next_token_scores.clone()
|
199 |
+
|
200 |
+
# Store scores, attentions and hidden_states when required
|
201 |
+
if return_dict_in_generate:
|
202 |
+
if output_scores:
|
203 |
+
scores += (next_token_scores,)
|
204 |
+
if output_logits:
|
205 |
+
raw_logits += (next_token_logits,)
|
206 |
+
if output_attentions:
|
207 |
+
decoder_attentions += (
|
208 |
+
(outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
|
209 |
+
)
|
210 |
+
if model.config.is_encoder_decoder:
|
211 |
+
cross_attentions += (outputs.cross_attentions,)
|
212 |
+
|
213 |
+
if output_hidden_states:
|
214 |
+
decoder_hidden_states += (
|
215 |
+
(outputs.decoder_hidden_states,)
|
216 |
+
if model.config.is_encoder_decoder
|
217 |
+
else (outputs.hidden_states,)
|
218 |
+
)
|
219 |
+
|
220 |
+
# reshape for beam search
|
221 |
+
vocab_size = next_token_scores.shape[-1]
|
222 |
+
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
223 |
+
|
224 |
+
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
225 |
+
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
226 |
+
next_token_scores, next_tokens = torch.topk(
|
227 |
+
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
|
228 |
+
)
|
229 |
+
|
230 |
+
next_indices = (next_tokens / vocab_size).long()
|
231 |
+
next_tokens = next_tokens % vocab_size
|
232 |
+
|
233 |
+
# stateless
|
234 |
+
beam_outputs = constrained_beam_scorer.process(
|
235 |
+
input_ids,
|
236 |
+
next_token_scores,
|
237 |
+
next_tokens,
|
238 |
+
next_indices,
|
239 |
+
scores_for_all_vocab,
|
240 |
+
pad_token_id=pad_token_id,
|
241 |
+
eos_token_id=eos_token_id,
|
242 |
+
beam_indices=beam_indices,
|
243 |
+
decoder_prompt_len=decoder_prompt_len,
|
244 |
+
)
|
245 |
+
beam_scores = beam_outputs["next_beam_scores"]
|
246 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
247 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
248 |
+
|
249 |
+
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
250 |
+
|
251 |
+
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
252 |
+
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
253 |
+
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
|
254 |
+
# (that way the memory peak does not include outputs.logits)
|
255 |
+
del outputs
|
256 |
+
|
257 |
+
# NOTE: we need to check if `model._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
|
258 |
+
if model_kwargs.get("past_key_values", None) is not None:
|
259 |
+
if hasattr(model, "_reorder_cache"):
|
260 |
+
model_kwargs["past_key_values"] = model._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
261 |
+
else:
|
262 |
+
model_kwargs["past_key_values"].reorder_cache(beam_idx)
|
263 |
+
|
264 |
+
if return_dict_in_generate and output_scores:
|
265 |
+
beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))
|
266 |
+
|
267 |
+
# increase cur_len
|
268 |
+
cur_len = cur_len + 1
|
269 |
+
|
270 |
+
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
271 |
+
this_peer_finished = True
|
272 |
+
|
273 |
+
sequence_outputs = constrained_beam_scorer.finalize(
|
274 |
+
input_ids,
|
275 |
+
beam_scores,
|
276 |
+
next_tokens,
|
277 |
+
next_indices,
|
278 |
+
pad_token_id=pad_token_id,
|
279 |
+
eos_token_id=eos_token_id,
|
280 |
+
max_length=stopping_criteria.max_length,
|
281 |
+
beam_indices=beam_indices,
|
282 |
+
decoder_prompt_len=decoder_prompt_len,
|
283 |
+
)
|
284 |
+
|
285 |
+
if return_dict_in_generate:
|
286 |
+
if not output_scores:
|
287 |
+
sequence_outputs["sequence_scores"] = None
|
288 |
+
if model.config.is_encoder_decoder:
|
289 |
+
return GenerateBeamEncoderDecoderOutput(
|
290 |
+
sequences=sequence_outputs["sequences"],
|
291 |
+
sequences_scores=sequence_outputs["sequence_scores"],
|
292 |
+
scores=scores,
|
293 |
+
logits=raw_logits,
|
294 |
+
beam_indices=sequence_outputs["beam_indices"],
|
295 |
+
encoder_attentions=encoder_attentions,
|
296 |
+
encoder_hidden_states=encoder_hidden_states,
|
297 |
+
decoder_attentions=decoder_attentions,
|
298 |
+
cross_attentions=cross_attentions,
|
299 |
+
decoder_hidden_states=decoder_hidden_states,
|
300 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
return GenerateBeamDecoderOnlyOutput(
|
304 |
+
sequences=sequence_outputs["sequences"],
|
305 |
+
sequences_scores=sequence_outputs["sequence_scores"],
|
306 |
+
scores=scores,
|
307 |
+
logits=raw_logits,
|
308 |
+
beam_indices=sequence_outputs["beam_indices"],
|
309 |
+
attentions=decoder_attentions,
|
310 |
+
hidden_states=decoder_hidden_states,
|
311 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
return sequence_outputs["sequences"]
|
315 |
+
|
316 |
+
def generate(model, *args, **kwargs):
|
317 |
+
"""Custom generate function for constrained beam search decoding.
|
318 |
+
Args:
|
319 |
+
model (`PreTrainedModel`):
|
320 |
+
The model to generate from.
|
321 |
+
num_beams (`int`): The number of beams to use for beam search.
|
322 |
+
constraints (`list[Constraint]`, *optional*):
|
323 |
+
Custom constraints that can be added to the generation to ensure that the output will contain the use of
|
324 |
+
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
|
325 |
+
force_words_ids (`list[list[list[int]]]`): List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple list of
|
326 |
+
words that must be included, the opposite to `bad_words_ids`. If given `list[list[list[int]]]`, this
|
327 |
+
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
|
328 |
+
can allow different forms of each word.
|
329 |
+
length_penalty (`float`): The length penalty to use for beam search.
|
330 |
+
early_stopping (`bool`): Whether to stop beam search when sufficient beams have finished.
|
331 |
+
num_return_sequences (`int`): The number of sequences to return.
|
332 |
+
max_length (`int`): The maximum length of the generated sequence.
|
333 |
+
"""
|
334 |
+
generation_outputs = GenerationMixin.generate(
|
335 |
+
model, *args, custom_generate=_constrained_beam_search, **kwargs
|
336 |
+
)
|
337 |
+
return generation_outputs
|
generation_config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 151643,
|
3 |
+
"do_sample": true,
|
4 |
+
"eos_token_id": [
|
5 |
+
151645,
|
6 |
+
151643
|
7 |
+
],
|
8 |
+
"pad_token_id": 151643,
|
9 |
+
"temperature": 0.6,
|
10 |
+
"top_k": 20,
|
11 |
+
"top_p": 0.95,
|
12 |
+
"transformers_version": "4.56.0"
|
13 |
+
}
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f47f71177f32bcd101b7573ec9171e6a57f4f4d31148d38e382306f42996874b
|
3 |
+
size 1503300328
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
|
3 |
+
size 11422654
|
tokenizer_config.json
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_prefix_space": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"151643": {
|
6 |
+
"content": "<|endoftext|>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"151644": {
|
14 |
+
"content": "<|im_start|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"151645": {
|
22 |
+
"content": "<|im_end|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"151646": {
|
30 |
+
"content": "<|object_ref_start|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
},
|
37 |
+
"151647": {
|
38 |
+
"content": "<|object_ref_end|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"151648": {
|
46 |
+
"content": "<|box_start|>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": false,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"151649": {
|
54 |
+
"content": "<|box_end|>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": false,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"151650": {
|
62 |
+
"content": "<|quad_start|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": false,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"151651": {
|
70 |
+
"content": "<|quad_end|>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": false,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"151652": {
|
78 |
+
"content": "<|vision_start|>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": false,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"151653": {
|
86 |
+
"content": "<|vision_end|>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": false,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"151654": {
|
94 |
+
"content": "<|vision_pad|>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": false,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"151655": {
|
102 |
+
"content": "<|image_pad|>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": false,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"151656": {
|
110 |
+
"content": "<|video_pad|>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": false,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
},
|
117 |
+
"151657": {
|
118 |
+
"content": "<tool_call>",
|
119 |
+
"lstrip": false,
|
120 |
+
"normalized": false,
|
121 |
+
"rstrip": false,
|
122 |
+
"single_word": false,
|
123 |
+
"special": false
|
124 |
+
},
|
125 |
+
"151658": {
|
126 |
+
"content": "</tool_call>",
|
127 |
+
"lstrip": false,
|
128 |
+
"normalized": false,
|
129 |
+
"rstrip": false,
|
130 |
+
"single_word": false,
|
131 |
+
"special": false
|
132 |
+
},
|
133 |
+
"151659": {
|
134 |
+
"content": "<|fim_prefix|>",
|
135 |
+
"lstrip": false,
|
136 |
+
"normalized": false,
|
137 |
+
"rstrip": false,
|
138 |
+
"single_word": false,
|
139 |
+
"special": false
|
140 |
+
},
|
141 |
+
"151660": {
|
142 |
+
"content": "<|fim_middle|>",
|
143 |
+
"lstrip": false,
|
144 |
+
"normalized": false,
|
145 |
+
"rstrip": false,
|
146 |
+
"single_word": false,
|
147 |
+
"special": false
|
148 |
+
},
|
149 |
+
"151661": {
|
150 |
+
"content": "<|fim_suffix|>",
|
151 |
+
"lstrip": false,
|
152 |
+
"normalized": false,
|
153 |
+
"rstrip": false,
|
154 |
+
"single_word": false,
|
155 |
+
"special": false
|
156 |
+
},
|
157 |
+
"151662": {
|
158 |
+
"content": "<|fim_pad|>",
|
159 |
+
"lstrip": false,
|
160 |
+
"normalized": false,
|
161 |
+
"rstrip": false,
|
162 |
+
"single_word": false,
|
163 |
+
"special": false
|
164 |
+
},
|
165 |
+
"151663": {
|
166 |
+
"content": "<|repo_name|>",
|
167 |
+
"lstrip": false,
|
168 |
+
"normalized": false,
|
169 |
+
"rstrip": false,
|
170 |
+
"single_word": false,
|
171 |
+
"special": false
|
172 |
+
},
|
173 |
+
"151664": {
|
174 |
+
"content": "<|file_sep|>",
|
175 |
+
"lstrip": false,
|
176 |
+
"normalized": false,
|
177 |
+
"rstrip": false,
|
178 |
+
"single_word": false,
|
179 |
+
"special": false
|
180 |
+
},
|
181 |
+
"151665": {
|
182 |
+
"content": "<tool_response>",
|
183 |
+
"lstrip": false,
|
184 |
+
"normalized": false,
|
185 |
+
"rstrip": false,
|
186 |
+
"single_word": false,
|
187 |
+
"special": false
|
188 |
+
},
|
189 |
+
"151666": {
|
190 |
+
"content": "</tool_response>",
|
191 |
+
"lstrip": false,
|
192 |
+
"normalized": false,
|
193 |
+
"rstrip": false,
|
194 |
+
"single_word": false,
|
195 |
+
"special": false
|
196 |
+
},
|
197 |
+
"151667": {
|
198 |
+
"content": "<think>",
|
199 |
+
"lstrip": false,
|
200 |
+
"normalized": false,
|
201 |
+
"rstrip": false,
|
202 |
+
"single_word": false,
|
203 |
+
"special": false
|
204 |
+
},
|
205 |
+
"151668": {
|
206 |
+
"content": "</think>",
|
207 |
+
"lstrip": false,
|
208 |
+
"normalized": false,
|
209 |
+
"rstrip": false,
|
210 |
+
"single_word": false,
|
211 |
+
"special": false
|
212 |
+
}
|
213 |
+
},
|
214 |
+
"additional_special_tokens": [
|
215 |
+
"<|im_start|>",
|
216 |
+
"<|im_end|>",
|
217 |
+
"<|object_ref_start|>",
|
218 |
+
"<|object_ref_end|>",
|
219 |
+
"<|box_start|>",
|
220 |
+
"<|box_end|>",
|
221 |
+
"<|quad_start|>",
|
222 |
+
"<|quad_end|>",
|
223 |
+
"<|vision_start|>",
|
224 |
+
"<|vision_end|>",
|
225 |
+
"<|vision_pad|>",
|
226 |
+
"<|image_pad|>",
|
227 |
+
"<|video_pad|>"
|
228 |
+
],
|
229 |
+
"bos_token": null,
|
230 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
231 |
+
"clean_up_tokenization_spaces": false,
|
232 |
+
"eos_token": "<|im_end|>",
|
233 |
+
"errors": "replace",
|
234 |
+
"model_max_length": 131072,
|
235 |
+
"pad_token": "<|endoftext|>",
|
236 |
+
"split_special_tokens": false,
|
237 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
238 |
+
"unk_token": null
|
239 |
+
}
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|