work in progress
This commit is contained in:
@@ -957,7 +957,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs):
|
||||
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs, attention_mask):
|
||||
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
@@ -969,6 +970,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -1132,6 +1134,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
lprobs[:, eos_token_id + 1 :] = -math.inf
|
||||
assert self._do_output_past(outputs)
|
||||
decoder_cache = outputs[1]
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
|
||||
num_hypos = batch_size * num_beams
|
||||
|
||||
Reference in New Issue
Block a user