From 11573231c6806e1864840a09060cafe78bc6acf8 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 16 Mar 2020 12:47:53 -0400 Subject: [PATCH] [BART] generation_mode as a kwarg not a class attribute (#3278) --- src/transformers/modeling_bart.py | 13 +++++++++---- src/transformers/modeling_utils.py | 4 ---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 73d2d22898..5976d49074 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -437,7 +437,6 @@ class BartDecoder(nn.Module): [DecoderLayer(config) for _ in range(config.decoder_layers)] ) # type: List[DecoderLayer] self.layernorm_embedding = LayerNorm(config.d_model) - self.generation_mode = False def forward( self, @@ -446,6 +445,7 @@ class BartDecoder(nn.Module): encoder_padding_mask, combined_mask, decoder_cached_states=None, + generation_mode=False, **unused ): """ @@ -474,9 +474,9 @@ class BartDecoder(nn.Module): assert encoder_padding_mask.max() <= 0 # embed positions - positions = self.embed_positions(input_ids, generation_mode=self.generation_mode) + positions = self.embed_positions(input_ids, generation_mode=generation_mode) - if self.generation_mode: + if generation_mode: input_ids = input_ids[:, -1:] positions = positions[:, -1:] # happens after we embed them assert input_ids.ne(self.padding_idx).any() @@ -820,10 +820,11 @@ class BartModel(PretrainedBartModel): encoder_outputs=None, # type: Tuple decoder_attention_mask=None, decoder_cached_states=None, + generation_mode=False, ): # make masks if user doesn't supply - if not self.decoder.generation_mode: + if not generation_mode: decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs( self.config, input_ids, @@ -842,6 +843,7 @@ class BartModel(PretrainedBartModel): attention_mask, decoder_attention_mask, decoder_cached_states=decoder_cached_states, + generation_mode=generation_mode, ) # Attention and hidden_states will be [] or None if they aren't needed decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple @@ -886,6 +888,7 @@ class BartForConditionalGeneration(PretrainedBartModel): decoder_attention_mask=None, decoder_cached_states=None, lm_labels=None, + generation_mode=False, **unused ): r""" @@ -936,6 +939,7 @@ class BartForConditionalGeneration(PretrainedBartModel): encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, decoder_cached_states=decoder_cached_states, + generation_mode=generation_mode, ) lm_logits = self.lm_head(outputs[0]) outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here @@ -963,6 +967,7 @@ class BartForConditionalGeneration(PretrainedBartModel): "decoder_cached_states": decoder_cached_states, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "generation_mode": True, } def prepare_scores_for_generation(self, scores, cur_len, max_length): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 57b4204a53..467c329f5e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -846,7 +846,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): attention_mask = attention_mask.contiguous().view( effective_batch_size * num_beams, input_ids_len ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) - if self.config.is_encoder_decoder: assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id" # encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs @@ -859,9 +858,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) cur_len = 1 - # put model in generation mode if it has one - if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"): - self.model.decoder.generation_mode = True else: encoder_inputs = None cur_len = input_ids.shape[-1]