Do not remove half seq length in generation tests (#30016)

* remove seq length from generation tests

* style and quality

* [test_all] & PR suggestion

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* [test all] remove unused variables

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-04-19 21:32:52 +05:00
committed by GitHub
parent b4fd49b6c5
commit b1cd48740e
10 changed files with 180 additions and 261 deletions

View File

@@ -299,12 +299,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
input_ids = input_ids[:batch_size, :sequence_length]
attention_mask = attention_mask[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
return config, input_ids, attention_mask
def setUp(self):
self.model_tester = BigBirdPegasusModelTester(self)