Generation tests: don't rely on main input name (#34228)

* don't rely on main input name

* update
This commit is contained in:
Raushan Turganbay
2024-10-21 10:00:14 +02:00
committed by GitHub
parent 816f442496
commit ca541bd4f4
3 changed files with 47 additions and 36 deletions

View File

@@ -618,14 +618,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
def test_generate_without_input_ids(self):
pass
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
# In this model, the index of `batch_size` and `sequence_length`` in `main_input` is different: they are the
# first two dimensions of the tensor.
main_input = main_input[:, :, 0]
super()._check_outputs(
output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
)
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")