TF generate refactor - Beam Search (#16374)

* refactor TF beam search

* refactored generate can now properly use attention masks

* add force bos/eos logit processors
This commit is contained in:
Joao Gante
2022-04-06 18:19:34 +01:00
committed by GitHub
parent 4d10083539
commit 3f43d824b9
11 changed files with 796 additions and 56 deletions

View File

@@ -472,14 +472,14 @@ class LogitsProcessorTest(unittest.TestCase):
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
# check that all scores are -inf except the eos_token_id when max_length is reached
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all())
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
# check that eos_token_id is not forced if max_length is not reached
# check that eos_token_id is not forced if max_length-1 is not reached
input_ids = ids_tensor((batch_size, 3), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)