Bart: fix layerdrop and cached decoder_input_ids for generation (#2969)
This commit is contained in:
@@ -251,6 +251,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
output_past=True,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
lm_model.eval()
|
||||
new_input_ids = lm_model.generate(input_ids)
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user