work in progress

This commit is contained in:
Patrick von Platen
2020-03-06 14:39:28 +01:00
parent 5b3000d933
commit 7a11e925cf
6 changed files with 176 additions and 39 deletions

View File

@@ -957,7 +957,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
}
@staticmethod
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs):
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs, attention_mask):
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
@@ -969,6 +970,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask
}
@staticmethod
@@ -1132,6 +1134,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
lprobs[:, eos_token_id + 1 :] = -math.inf
assert self._do_output_past(outputs)
decoder_cache = outputs[1]
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
num_hypos = batch_size * num_beams