tf generation utils: remove unused kwargs (#6591)
This commit is contained in:
@@ -284,7 +284,7 @@ class TFGenerationMixin:
|
|||||||
pad_token_id = eos_token_id
|
pad_token_id = eos_token_id
|
||||||
|
|
||||||
# current position and vocab size
|
# current position and vocab size
|
||||||
cur_len = shape_list(input_ids)[1]
|
cur_len = shape_list(input_ids)[1] # unused
|
||||||
vocab_size = self.config.vocab_size
|
vocab_size = self.config.vocab_size
|
||||||
|
|
||||||
# set effective batch size and effective batch multiplier according to do_sample
|
# set effective batch size and effective batch multiplier according to do_sample
|
||||||
@@ -366,10 +366,8 @@ class TFGenerationMixin:
|
|||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
bos_token_id=bos_token_id,
|
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
|
||||||
batch_size=effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
num_return_sequences=num_return_sequences,
|
num_return_sequences=num_return_sequences,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
@@ -392,10 +390,8 @@ class TFGenerationMixin:
|
|||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
bos_token_id=bos_token_id,
|
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
|
||||||
batch_size=effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
@@ -418,10 +414,8 @@ class TFGenerationMixin:
|
|||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size,
|
||||||
bad_words_ids,
|
bad_words_ids,
|
||||||
bos_token_id,
|
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_id,
|
eos_token_id,
|
||||||
decoder_start_token_id,
|
|
||||||
batch_size,
|
batch_size,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
encoder_outputs,
|
encoder_outputs,
|
||||||
@@ -582,9 +576,7 @@ class TFGenerationMixin:
|
|||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size,
|
||||||
bad_words_ids,
|
bad_words_ids,
|
||||||
bos_token_id,
|
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
decoder_start_token_id,
|
|
||||||
eos_token_id,
|
eos_token_id,
|
||||||
batch_size,
|
batch_size,
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
@@ -616,6 +608,7 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# cache compute states
|
# cache compute states
|
||||||
past = encoder_outputs
|
past = encoder_outputs
|
||||||
|
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
|
||||||
|
|
||||||
# done sentences
|
# done sentences
|
||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|||||||
Reference in New Issue
Block a user