Do not remove half seq length in generation tests (#30016)
* remove seq length from generation tests * style and quality * [test_all] & PR suggestion Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * [test all] remove unused variables --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b4fd49b6c5
commit
b1cd48740e
@@ -299,12 +299,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
input_ids = input_ids[:batch_size, :sequence_length]
|
||||
attention_mask = attention_mask[:batch_size, :sequence_length]
|
||||
|
||||
# generate max 3 tokens
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, max_length
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BigBirdPegasusModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user