From 92487a1dc03c919afa8a961ed7d8ba78fafa21bd Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 22 Feb 2020 16:25:04 -0500 Subject: [PATCH] Bart: fix layerdrop and cached decoder_input_ids for generation (#2969) --- src/transformers/modeling_bart.py | 8 ++------ tests/test_modeling_bart.py | 1 + 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 3ec76cf675..f329eb6842 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -470,10 +470,6 @@ class BartDecoder(nn.Module): """ # embed positions positions = self.embed_positions(input_ids) - - if decoder_cached_states is not None: - input_ids = input_ids[:, -1:] - positions = positions[:, -1:] x = self.embed_tokens(input_ids) if positions is not None: @@ -491,7 +487,7 @@ class BartDecoder(nn.Module): decoder_layer # type: DecoderLayer # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability > self.layerdrop): + if self.training and (dropout_probability < self.layerdrop): continue layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None x, layer_self_attn, layer_past = decoder_layer.forward( @@ -940,7 +936,7 @@ class BartForMaskedLM(PretrainedBartModel): @staticmethod def prepare_inputs_for_generation(input_ids, past, **kwargs): - return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids} + return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]} def get_output_embeddings(self): return self.lm_head diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 61f6098610..927c37eadf 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -251,6 +251,7 @@ class BartHeadTests(unittest.TestCase): output_past=True, ) lm_model = BartForMaskedLM(config) + lm_model.eval() new_input_ids = lm_model.generate(input_ids) self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))