Fixes Pegasus tokenization tests (#10671)
This commit is contained in:
@@ -356,7 +356,7 @@ class TFPegasusIntegrationTests(unittest.TestCase):
|
|||||||
assert self.expected_text == generated_words
|
assert self.expected_text == generated_words
|
||||||
|
|
||||||
def translate_src_text(self, **tokenizer_kwargs):
|
def translate_src_text(self, **tokenizer_kwargs):
|
||||||
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
|
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, padding=True, return_tensors="tf")
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
model_inputs.input_ids,
|
model_inputs.input_ids,
|
||||||
attention_mask=model_inputs.attention_mask,
|
attention_mask=model_inputs.attention_mask,
|
||||||
|
|||||||
Reference in New Issue
Block a user