tf generation utils: remove unused kwargs (#6591)

This commit is contained in:
Sam Shleifer
2020-08-19 09:37:45 -04:00
committed by GitHub
parent 2a7402cbd3
commit 9a86321b11

View File

@@ -284,7 +284,7 @@ class TFGenerationMixin:
pad_token_id = eos_token_id
# current position and vocab size
cur_len = shape_list(input_ids)[1]
cur_len = shape_list(input_ids)[1] # unused
vocab_size = self.config.vocab_size
# set effective batch size and effective batch multiplier according to do_sample
@@ -366,10 +366,8 @@ class TFGenerationMixin:
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
@@ -392,10 +390,8 @@ class TFGenerationMixin:
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
@@ -418,10 +414,8 @@ class TFGenerationMixin:
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
bos_token_id,
pad_token_id,
eos_token_id,
decoder_start_token_id,
batch_size,
vocab_size,
encoder_outputs,
@@ -582,9 +576,7 @@ class TFGenerationMixin:
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
bos_token_id,
pad_token_id,
decoder_start_token_id,
eos_token_id,
batch_size,
num_return_sequences,
@@ -616,6 +608,7 @@ class TFGenerationMixin:
# cache compute states
past = encoder_outputs
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
# done sentences
done = [False for _ in range(batch_size)]