Fix TFEncoderDecoder tests (#21301)

remove max_length=None

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-01-26 16:56:42 +01:00
committed by GitHub
parent 857bad6e53
commit 449df41f01

View File

@@ -785,7 +785,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""
input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
@@ -793,7 +793,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
# Test with the TF checkpoint
model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
@@ -887,7 +887,7 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
EXPECTED_SUMMARY_STUDENTS = """SAS Alpha Epsilon suspended the students, but university president says it's permanent.\nThe fraternity has had to deal with a string of student deaths since 2010.\nSAS has more than 200,000 members, many of whom are students.\nA student died while being forced into excessive alcohol consumption."""
input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])