[BART] generation_mode as a kwarg not a class attribute (#3278)
This commit is contained in:
@@ -437,7 +437,6 @@ class BartDecoder(nn.Module):
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
-self.generation_mode = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -446,6 +445,7 @@ class BartDecoder(nn.Module):
|
||||
encoder_padding_mask,
|
||||
combined_mask,
|
||||
decoder_cached_states=None,
|
||||
+generation_mode=False,
|
||||
**unused
|
||||
):
|
||||
"""
|
||||
@@ -474,9 +474,9 @@ class BartDecoder(nn.Module):
|
||||
assert encoder_padding_mask.max() <= 0
|
||||
|
||||
# embed positions
|
||||
-positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
|
||||
+positions = self.embed_positions(input_ids, generation_mode=generation_mode)
|
||||
|
||||
-if self.generation_mode:
|
||||
+if generation_mode:
|
||||
input_ids = input_ids[:, -1:]
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
@@ -820,10 +820,11 @@ class BartModel(PretrainedBartModel):
|
||||
encoder_outputs=None, # type: Tuple
|
||||
decoder_attention_mask=None,
|
||||
decoder_cached_states=None,
|
||||
+generation_mode=False,
|
||||
):
|
||||
|
||||
# make masks if user doesn't supply
|
||||
-if not self.decoder.generation_mode:
|
||||
+if not generation_mode:
|
||||
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
|
||||
self.config,
|
||||
input_ids,
|
||||
@@ -842,6 +843,7 @@ class BartModel(PretrainedBartModel):
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
+generation_mode=generation_mode,
|
||||
)
|
||||
# Attention and hidden_states will be [] or None if they aren't needed
|
||||
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
|
||||
@@ -886,6 +888,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
decoder_attention_mask=None,
|
||||
decoder_cached_states=None,
|
||||
lm_labels=None,
|
||||
+generation_mode=False,
|
||||
**unused
|
||||
):
|
||||
r"""
|
||||
@@ -936,6 +939,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
+generation_mode=generation_mode,
|
||||
)
|
||||
lm_logits = self.lm_head(outputs[0])
|
||||
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||
@@ -963,6 +967,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
+"generation_mode": True,
|
||||
}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
|
||||
Reference in New Issue
Block a user