From 1ba21f96ca0795095a32be485c7100b5e2062592 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Mar 2020 20:29:30 +0100 Subject: [PATCH] fix bug in tf no_repeat_ngram_size --- src/transformers/modeling_tf_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 8cab7619a0..a0247015bd 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -942,7 +942,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) + num_batch_hypotheses = batch_size * num_beams + banned_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len) # create banned_tokens boolean mask banned_tokens_indices_mask = [] for banned_tokens_slice in banned_tokens: