From b3b9f99ed216ee5faa899f1047e43002c6a222c0 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 18 May 2022 17:57:23 +0200 Subject: [PATCH] Fix test_t5_decoder_model_past_large_inputs (#17320) Co-authored-by: ydshieh --- tests/models/t5/test_modeling_tf_t5.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py index 91bc63feda..5ad746e34f 100644 --- a/tests/models/t5/test_modeling_tf_t5.py +++ b/tests/models/t5/test_modeling_tf_t5.py @@ -295,6 +295,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): def test_t5_decoder_model_past_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() + + # `create_and_check_t5_decoder_model_past_large_inputs` has special inputs: + # (config, input_ids, decoder_input_ids, attention_mask) + # and we have to prepare it correctly here. + config, input_ids, input_mask, token_labels = config_and_inputs + config_and_inputs = (config, input_ids, None, input_mask) + self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs) def test_t5_model_xla_generate_fast(self):