From cb966e640b8b9d0f6e9c06c1655d078a917e5196 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 4 Nov 2020 15:44:36 +0100 Subject: [PATCH] [Generate Test] fix greedy generate test (#8293) * fix greedy generate test * delet ipdb --- tests/test_generation_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 0cdd80dd74..ab07987315 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -140,10 +140,6 @@ class GenerationTesterMixin: # check `generate()` and `greedy_search()` are equal kwargs = {} if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, input_ids, attention_mask - ) - kwargs["encoder_outputs"] = encoder_outputs max_length = 4 output_ids_generate = model.generate( @@ -154,6 +150,13 @@ class GenerationTesterMixin: max_length=max_length, **logits_process_kwargs, ) + + if model.config.is_encoder_decoder: + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, input_ids, attention_mask + ) + kwargs["encoder_outputs"] = encoder_outputs + with torch.no_grad(): output_ids_greedy = model.greedy_search( input_ids,