[BART] generation_mode as a kwarg not a class attribute (#3278)

This commit is contained in:
Sam Shleifer
2020-03-16 12:47:53 -04:00
committed by GitHub
parent de697935a2
commit 11573231c6
2 changed files with 9 additions and 8 deletions

View File

@@ -437,7 +437,6 @@ class BartDecoder(nn.Module):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.generation_mode = False
def forward(
self,
@@ -446,6 +445,7 @@ class BartDecoder(nn.Module):
encoder_padding_mask,
combined_mask,
decoder_cached_states=None,
generation_mode=False,
**unused
):
"""
@@ -474,9 +474,9 @@ class BartDecoder(nn.Module):
assert encoder_padding_mask.max() <= 0
# embed positions
positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
positions = self.embed_positions(input_ids, generation_mode=generation_mode)
if self.generation_mode:
if generation_mode:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
@@ -820,10 +820,11 @@ class BartModel(PretrainedBartModel):
encoder_outputs=None, # type: Tuple
decoder_attention_mask=None,
decoder_cached_states=None,
generation_mode=False,
):
# make masks if user doesn't supply
if not self.decoder.generation_mode:
if not generation_mode:
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
@@ -842,6 +843,7 @@ class BartModel(PretrainedBartModel):
attention_mask,
decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
# Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
@@ -886,6 +888,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_attention_mask=None,
decoder_cached_states=None,
lm_labels=None,
generation_mode=False,
**unused
):
r"""
@@ -936,6 +939,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
lm_logits = self.lm_head(outputs[0])
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
@@ -963,6 +967,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"generation_mode": True,
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):