Bart: fix layerdrop and cached decoder_input_ids for generation (#2969)

This commit is contained in:
Sam Shleifer
2020-02-22 16:25:04 -05:00
committed by GitHub
parent c36416e53c
commit 92487a1dc0
2 changed files with 3 additions and 6 deletions

View File

@@ -251,6 +251,7 @@ class BartHeadTests(unittest.TestCase):
output_past=True,
)
lm_model = BartForMaskedLM(config)
lm_model.eval()
new_input_ids = lm_model.generate(input_ids)
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))