Fixes Pegasus tokenization tests (#10671)

This commit is contained in:
Lysandre Debut
2021-03-11 13:35:50 -05:00
committed by GitHub
parent 7e4428749c
commit a637ae00c4

View File

@@ -356,7 +356,7 @@ class TFPegasusIntegrationTests(unittest.TestCase):
assert self.expected_text == generated_words assert self.expected_text == generated_words
def translate_src_text(self, **tokenizer_kwargs): def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf") model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, padding=True, return_tensors="tf")
generated_ids = self.model.generate( generated_ids = self.model.generate(
model_inputs.input_ids, model_inputs.input_ids,
attention_mask=model_inputs.attention_mask, attention_mask=model_inputs.attention_mask,