VLMs: enable generation tests (#33533)
* add tests * fix whisper * update * nit * add qwen2-vl * more updates! * better this way * fix this one * fix more tests * fix final tests, hope so * fix led * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * pr comments * not pass pixels and extra for low-mem tests, very flaky because of visio tower --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e40bb4845e
commit
d7975a5874
@@ -286,12 +286,19 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
inputs_dict = {
|
||||
k: v[:batch_size, ...]
|
||||
for k, v in inputs_dict.items()
|
||||
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||
}
|
||||
|
||||
# take max batch_size
|
||||
sequence_length = input_ids.shape[-1]
|
||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
return config, input_ids, attention_mask
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
logits_processor_kwargs = {}
|
||||
@@ -299,7 +306,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
@@ -310,6 +317,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
inputs_dict={},
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
Reference in New Issue
Block a user