From 9a86321b11b0ee03d8803d7e21b50012252a76ac Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 19 Aug 2020 09:37:45 -0400 Subject: [PATCH] tf generation utils: remove unused kwargs (#6591) --- src/transformers/generation_tf_utils.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 4e76725947..20b49764fc 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -284,7 +284,7 @@ class TFGenerationMixin: pad_token_id = eos_token_id # current position and vocab size - cur_len = shape_list(input_ids)[1] + cur_len = shape_list(input_ids)[1] # unused vocab_size = self.config.vocab_size # set effective batch size and effective batch multiplier according to do_sample @@ -366,10 +366,8 @@ class TFGenerationMixin: repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, bad_words_ids=bad_words_ids, - bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_id=eos_token_id, - decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, num_return_sequences=num_return_sequences, length_penalty=length_penalty, @@ -392,10 +390,8 @@ class TFGenerationMixin: repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, bad_words_ids=bad_words_ids, - bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_id=eos_token_id, - decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, vocab_size=vocab_size, encoder_outputs=encoder_outputs, @@ -418,10 +414,8 @@ class TFGenerationMixin: repetition_penalty, no_repeat_ngram_size, bad_words_ids, - bos_token_id, pad_token_id, eos_token_id, - decoder_start_token_id, batch_size, vocab_size, encoder_outputs, @@ -582,9 +576,7 @@ class TFGenerationMixin: repetition_penalty, no_repeat_ngram_size, bad_words_ids, - bos_token_id, pad_token_id, - decoder_start_token_id, eos_token_id, batch_size, num_return_sequences, @@ -616,6 +608,7 @@ class TFGenerationMixin: # cache compute states past = encoder_outputs + # to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None # done sentences done = [False for _ in range(batch_size)]