tf generation utils: remove unused kwargs (#6591)
This commit is contained in:
@@ -284,7 +284,7 @@ class TFGenerationMixin:
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# current position and vocab size
|
||||
cur_len = shape_list(input_ids)[1]
|
||||
cur_len = shape_list(input_ids)[1] # unused
|
||||
vocab_size = self.config.vocab_size
|
||||
|
||||
# set effective batch size and effective batch multiplier according to do_sample
|
||||
@@ -366,10 +366,8 @@ class TFGenerationMixin:
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bad_words_ids=bad_words_ids,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
num_return_sequences=num_return_sequences,
|
||||
length_penalty=length_penalty,
|
||||
@@ -392,10 +390,8 @@ class TFGenerationMixin:
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bad_words_ids=bad_words_ids,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
@@ -418,10 +414,8 @@ class TFGenerationMixin:
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size,
|
||||
bad_words_ids,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
@@ -582,9 +576,7 @@ class TFGenerationMixin:
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size,
|
||||
bad_words_ids,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
decoder_start_token_id,
|
||||
eos_token_id,
|
||||
batch_size,
|
||||
num_return_sequences,
|
||||
@@ -616,6 +608,7 @@ class TFGenerationMixin:
|
||||
|
||||
# cache compute states
|
||||
past = encoder_outputs
|
||||
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
|
||||
|
||||
# done sentences
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
Reference in New Issue
Block a user