Generation tests: don't rely on main input name (#34228)

* don't rely on main input name

* update
This commit is contained in:
Raushan Turganbay
2024-10-21 10:00:14 +02:00
committed by GitHub
parent 816f442496
commit ca541bd4f4
3 changed files with 47 additions and 36 deletions

View File

@@ -53,6 +53,7 @@ class ReformerModelTester:
parent,
batch_size=13,
seq_length=32,
text_seq_length=None,
is_training=True,
is_decoder=True,
use_input_mask=True,
@@ -128,6 +129,7 @@ class ReformerModelTester:
self.attn_layers = attn_layers
self.pad_token_id = pad_token_id
self.hash_seed = hash_seed
self.text_seq_length = text_seq_length or seq_length
attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length
num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after
@@ -608,7 +610,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
test_sequence_classification_problem_types = True
def setUp(self):
self.model_tester = ReformerModelTester(self)
self.model_tester = ReformerModelTester(self, text_seq_length=16)
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
@slow
@@ -689,7 +691,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
original_sequence_length = self.model_tester.seq_length
self.model_tester.seq_length = 16
self.model_tester.seq_length = self.model_tester.text_seq_length
test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
self.model_tester.seq_length = original_sequence_length
return test_inputs

View File

@@ -618,14 +618,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
def test_generate_without_input_ids(self):
pass
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
# In this model, the index of `batch_size` and `sequence_length`` in `main_input` is different: they are the
# first two dimensions of the tensor.
main_input = main_input[:, :, 0]
super()._check_outputs(
output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
)
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")