manueldeprada HF Staff commited on
Commit
fe87ab3
·
verified ·
1 Parent(s): a8edf3a

Upload folder using huggingface_hub

Browse files
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Automatically created by ruff.
2
+ *
.ruff_cache/0.12.8/7010951691598163845 ADDED
Binary file (149 Bytes). View file
 
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
custom_generate/generate.py CHANGED
@@ -108,9 +108,9 @@ def _constrained_beam_search(
108
  # define beam scorer
109
  constrained_beam_scorer = ConstrainedBeamSearchScorer(
110
  constraints=final_constraints,
111
- batch_size=batch_size,
112
  num_beams=generation_config.num_beams,
113
- device=inputs_tensor.device,
114
  length_penalty=generation_config.length_penalty,
115
  do_early_stopping=generation_config.early_stopping,
116
  num_beam_hyps_to_keep=generation_config.num_return_sequences,
 
108
  # define beam scorer
109
  constrained_beam_scorer = ConstrainedBeamSearchScorer(
110
  constraints=final_constraints,
111
+ batch_size=input_ids.shape[0] // generation_config.num_beams,
112
  num_beams=generation_config.num_beams,
113
+ device=input_ids.device,
114
  length_penalty=generation_config.length_penalty,
115
  do_early_stopping=generation_config.early_stopping,
116
  num_beam_hyps_to_keep=generation_config.num_return_sequences,