Upload folder using huggingface_hub
Browse files
.ruff_cache/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Automatically created by ruff.
|
2 |
+
*
|
.ruff_cache/0.12.8/7010951691598163845
ADDED
Binary file (149 Bytes). View file
|
|
.ruff_cache/CACHEDIR.TAG
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
custom_generate/generate.py
CHANGED
@@ -108,9 +108,9 @@ def _constrained_beam_search(
|
|
108 |
# define beam scorer
|
109 |
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
110 |
constraints=final_constraints,
|
111 |
-
batch_size=
|
112 |
num_beams=generation_config.num_beams,
|
113 |
-
device=
|
114 |
length_penalty=generation_config.length_penalty,
|
115 |
do_early_stopping=generation_config.early_stopping,
|
116 |
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
|
|
108 |
# define beam scorer
|
109 |
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
110 |
constraints=final_constraints,
|
111 |
+
batch_size=input_ids.shape[0] // generation_config.num_beams,
|
112 |
num_beams=generation_config.num_beams,
|
113 |
+
device=input_ids.device,
|
114 |
length_penalty=generation_config.length_penalty,
|
115 |
do_early_stopping=generation_config.early_stopping,
|
116 |
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|