best current version and make style

This commit is contained in:
patrickvonplaten
2020-03-06 22:19:01 +01:00
committed by Patrick von Platen
parent c62444da39
commit 2acfe63964
4 changed files with 45 additions and 36 deletions

View File

@@ -942,7 +942,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
attention_mask.shape, encoder_inputs.shape
)
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
@@ -954,7 +956,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):