added beam_search generation for tf 2.0
This commit is contained in:
committed by
Patrick von Platen
parent
34de670dbe
commit
61fef6e957
@@ -381,7 +381,6 @@ class TFModelTesterMixin:
|
||||
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# TODO (PVP): add beam search tests when beam search is implemented
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
@@ -389,15 +388,34 @@ class TFModelTesterMixin:
|
||||
model.generate(max_length=5)
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(input_ids))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3))
|
||||
else:
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(max_length=5))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when greedy no beam generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# batch_size > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||
# batch_size > 1, greedy
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||
|
||||
# batch_size > 1, num_beams > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
||||
# batch_size > 1, num_beams > 1, greedy
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
|
||||
)
|
||||
|
||||
def _check_generated_tokens(self, output_ids):
|
||||
for token_id in output_ids[0].numpy().tolist():
|
||||
|
||||
Reference in New Issue
Block a user