VLMs: enable generation tests (#33533)

* add tests

* fix whisper

* update

* nit

* add qwen2-vl

* more updates!

* better this way

* fix this one

* fix more tests

* fix final tests, hope so

* fix led

* Update tests/generation/test_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* pr comments

* not pass pixels and extra for low-mem tests, very flaky because of visio tower

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2024-09-19 12:04:24 +02:00
committed by GitHub
parent e40bb4845e
commit d7975a5874
22 changed files with 500 additions and 207 deletions

View File

@@ -286,12 +286,19 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
_ = inputs_dict.pop("attention_mask", None)
inputs_dict = {
k: v[:batch_size, ...]
for k, v in inputs_dict.items()
if "head_mask" not in k and isinstance(v, torch.Tensor)
}
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask
return config, input_ids, attention_mask, inputs_dict
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
@@ -299,7 +306,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
@@ -310,6 +317,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
inputs_dict={},
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)