best current version and make style
This commit is contained in:
committed by
Patrick von Platen
parent
c62444da39
commit
2acfe63964
@@ -942,7 +942,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
|
||||
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
|
||||
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
|
||||
attention_mask.shape, encoder_inputs.shape
|
||||
)
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
@@ -954,7 +956,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
|
||||
Reference in New Issue
Block a user