fix bug in tf no_repeat_ngram_size

This commit is contained in:
Patrick von Platen
2020-03-10 20:29:30 +01:00
parent d997ac7810
commit 1ba21f96ca

View File

@@ -942,7 +942,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if no_repeat_ngram_size > 0: if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams # calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) num_batch_hypotheses = batch_size * num_beams
banned_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
# create banned_tokens boolean mask # create banned_tokens boolean mask
banned_tokens_indices_mask = [] banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens: for banned_tokens_slice in banned_tokens: