TF generate refactor - Beam Search (#16374)
* refactor TF beam search * refactored generate can now properly use attention masks * add force bos/eos logit processors
This commit is contained in:
@@ -1179,7 +1179,7 @@ class TFModelTesterMixin:
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user