TF generate refactor - Beam Search (#16374)

* refactor TF beam search

* refactored generate can now properly use attention masks

* add force bos/eos logit processors
This commit is contained in:
Joao Gante
2022-04-06 18:19:34 +01:00
committed by GitHub
parent 4d10083539
commit 3f43d824b9
11 changed files with 796 additions and 56 deletions

View File

@@ -1179,7 +1179,7 @@ class TFModelTesterMixin:
# num_return_sequences = 1
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
# generating more sequences than having beams leads is not possible
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)