correct greedy generation when doing beam search (#3078)

* correct greedy generation when doing beam search

* improve comment
This commit is contained in:
Patrick von Platen
2020-03-02 18:00:09 +01:00
committed by GitHub
parent 13afb71208
commit 2fdc7f6ce8
2 changed files with 41 additions and 7 deletions

View File

@@ -621,10 +621,19 @@ class ModelTesterMixin:
# batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
with self.assertRaises(AssertionError):
# generating multiple sequences when greedy no beam generation
# is not allowed as it would always generate the same sequences
model.generate(input_ids, do_sample=False, num_return_sequences=2)
with self.assertRaises(AssertionError):
# generating more sequences than having beams leads is not possible
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# batch_size > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
# batch_size > 1, greedy
self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
# batch_size > 1, num_beams > 1, greedy