From 3c6bf8998fb6ca5aca063fed2543b7176883b004 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 25 Sep 2020 04:24:14 -0400 Subject: [PATCH] modeling_bart: 3 small cleanups that dont change outputs (#7381) * Mbart passing * boom boom * cleaner assert * add assert * Fix tests --- src/transformers/modeling_bart.py | 11 ++++++++--- tests/test_modeling_mbart.py | 3 +-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index e48e430d34..7cf944a5c6 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -307,7 +307,7 @@ class BartEncoder(nn.Module): self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() # mbart has one extra layer_norm - self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None + self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None def forward( self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False @@ -550,6 +550,10 @@ class BartDecoder(nn.Module): positions = self.embed_positions(input_ids, use_cache=use_cache) if use_cache: + if input_ids.shape[1] != 1 or past_key_values is None: + # if you make this an AssertionError, test_benchmark breaks. + warnings.warn("pass decoder_past_key_value_states to use_cache") + input_ids = input_ids[:, -1:] positions = positions[:, -1:] # happens after we embed them # assert input_ids.ne(self.padding_idx).any() @@ -590,11 +594,12 @@ class BartDecoder(nn.Module): if use_cache: next_decoder_cache.append(layer_past.copy()) - if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART) - x = self.layer_norm(x) if output_attentions: all_self_attns += (layer_self_attn,) + if self.layer_norm: # if config.add_final_layer_norm (mBART) + x = self.layer_norm(x) + # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) if output_hidden_states: all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states) diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 7c7c0e06db..1bacd0ff0d 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -86,8 +86,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device) translated_tokens = self.model.generate(**batch) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) - self.assertEqual(self.tgt_text[0], decoded[0]) - self.assertEqual(self.tgt_text[1], decoded[1]) + assert self.tgt_text == decoded def test_mbart_enro_config(self): mbart_models = ["facebook/mbart-large-en-ro"]