fix bug in tf no_repeat_ngram_size
This commit is contained in:
@@ -942,7 +942,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
if no_repeat_ngram_size > 0:
|
if no_repeat_ngram_size > 0:
|
||||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
num_batch_hypotheses = batch_size * num_beams
|
||||||
|
banned_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
|
||||||
# create banned_tokens boolean mask
|
# create banned_tokens boolean mask
|
||||||
banned_tokens_indices_mask = []
|
banned_tokens_indices_mask = []
|
||||||
for banned_tokens_slice in banned_tokens:
|
for banned_tokens_slice in banned_tokens:
|
||||||
|
|||||||
Reference in New Issue
Block a user