Fix TFEncoderDecoder tests (#21301)
remove max_length=None Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -785,7 +785,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
||||
EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""
|
||||
|
||||
input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf")
|
||||
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
|
||||
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
|
||||
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
||||
@@ -793,7 +793,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
||||
# Test with the TF checkpoint
|
||||
model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
|
||||
|
||||
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
|
||||
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
|
||||
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
||||
@@ -887,7 +887,7 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
||||
EXPECTED_SUMMARY_STUDENTS = """SAS Alpha Epsilon suspended the students, but university president says it's permanent.\nThe fraternity has had to deal with a string of student deaths since 2010.\nSAS has more than 200,000 members, many of whom are students.\nA student died while being forced into excessive alcohol consumption."""
|
||||
|
||||
input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="tf")
|
||||
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
|
||||
output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
|
||||
summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
||||
|
||||
Reference in New Issue
Block a user