Do not remove half seq length in generation tests (#30016)

* remove seq length from generation tests

* style and quality

* [test_all] & PR suggestion

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* [test all] remove unused variables

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-04-19 21:32:52 +05:00
committed by GitHub
parent b4fd49b6c5
commit b1cd48740e
10 changed files with 180 additions and 261 deletions

View File

@@ -646,7 +646,8 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
seq_len = 1
else:
# for first item dummy PAD token is appended so need one more
seq_len = (min_length + 1) if idx == 0 else min_length
# else offset+dummy_token when using cache
seq_len = (min_length + 1) if idx == 0 else 3
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
self.assertEqual(layer_hidden_states.shape, expected_shape)
@@ -665,8 +666,11 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
tgt_len = min_length
# for first item dummy PAD token is appended so need one more
# every token after consists of offset+dummy_token length when using cache
if idx == 0:
tgt_len += 1
else:
tgt_len = 3
src_len = min_length + idx + 1