From 1a5c500f1232f5fd21caeab918b1534622926029 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 20 Mar 2024 07:45:53 +0000 Subject: [PATCH] Tests: Musicgen tests + `make fix-copies` (#29734) * make fix-copies * some tests fixed * tests fixed --- .../modeling_musicgen_melody.py | 4 +- .../models/musicgen/test_modeling_musicgen.py | 99 -------- .../test_modeling_musicgen_melody.py | 238 ++---------------- 3 files changed, 27 insertions(+), 314 deletions(-) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 660e37b451..8b5c5c2f57 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1294,7 +1294,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ) # 11. run greedy search - outputs = self.greedy_search( + outputs = self._greedy_search( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, @@ -1319,7 +1319,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ) # 12. run sample - outputs = self.sample( + outputs = self._sample( input_ids, logits_processor=logits_processor, logits_warper=logits_warper, diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index f7ceb0a8bf..adc3bf234e 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -257,105 +257,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste warper_kwargs = {} return process_kwargs, warper_kwargs - # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform - # additional post-processing in the former - def test_greedy_generate_dict_outputs(self): - for model_class in self.greedy_sample_model_classes: - # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config.use_cache = False - model = model_class(config).to(torch_device).eval() - output_generate = self._greedy_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) - self.assertNotIn(config.pad_token_id, output_generate) - - # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform - # additional post-processing in the former - def test_greedy_generate_dict_outputs_use_cache(self): - for model_class in self.greedy_sample_model_classes: - # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - config.use_cache = True - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - output_generate = self._greedy_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) - - # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform - # additional post-processing in the former - def test_sample_generate(self): - for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - model = model_class(config).to(torch_device).eval() - - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - - # check `generate()` and `sample()` are equal - output_generate = self._sample_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - num_return_sequences=3, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, - ) - self.assertIsInstance(output_generate, torch.Tensor) - - # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform - # additional post-processing in the former - def test_sample_generate_dict_output(self): - for model_class in self.greedy_sample_model_classes: - # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config.use_cache = False - model = model_class(config).to(torch_device).eval() - - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - - output_generate = self._sample_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - num_return_sequences=1, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) - def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index af10aa8846..7bb346d8ab 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -55,8 +55,6 @@ if is_torch_available(): ) from transformers.generation import ( GenerateDecoderOnlyOutput, - InfNanRemoveLogitsProcessor, - LogitsProcessorList, ) if is_torchaudio_available(): @@ -248,142 +246,24 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes return config, input_ids, attention_mask, max_length @staticmethod - def _get_logits_processor_and_kwargs( + def _get_logits_processor_and_warper_kwargs( input_length, - eos_token_id, forced_bos_token_id=None, forced_eos_token_id=None, max_length=None, - diversity_penalty=None, ): process_kwargs = { "min_length": input_length + 1 if max_length is None else max_length - 1, } - logits_processor = LogitsProcessorList() - return process_kwargs, logits_processor - - # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform - # additional post-processing in the former - def test_greedy_generate_dict_outputs(self): - for model_class in self.greedy_sample_model_classes: - # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config.use_cache = False - model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) - self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) - - self.assertNotIn(config.pad_token_id, output_generate) - - # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform - # additional post-processing in the former - def test_greedy_generate_dict_outputs_use_cache(self): - for model_class in self.greedy_sample_model_classes: - # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - config.use_cache = True - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) - self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) - - # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform - # additional post-processing in the former - def test_sample_generate(self): - for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - model = model_class(config).to(torch_device).eval() - - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, - ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) - - # check `generate()` and `sample()` are equal - output_sample, output_generate = self._sample_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - num_return_sequences=3, - logits_processor=logits_processor, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, - ) - self.assertIsInstance(output_sample, torch.Tensor) - self.assertIsInstance(output_generate, torch.Tensor) - - # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform - # additional post-processing in the former - def test_sample_generate_dict_output(self): - for model_class in self.greedy_sample_model_classes: - # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config.use_cache = False - model = model_class(config).to(torch_device).eval() - - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, - ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - - output_sample, output_generate = self._sample_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - max_length=max_length, - num_return_sequences=1, - logits_processor=logits_processor, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput) - self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + warper_kwargs = {} + return process_kwargs, warper_kwargs def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), @@ -394,9 +274,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) - self.assertNotIn(config.pad_token_id, output_generate) @@ -817,10 +695,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) # generate max 3 tokens - decoder_input_ids = inputs_dict["decoder_input_ids"] - max_length = decoder_input_ids.shape[-1] + 3 - decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :] - return config, input_ids, attention_mask, decoder_input_ids, max_length + max_length = 3 + return config, input_ids, attention_mask, max_length # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are # different modalities -> different shapes) @@ -829,18 +705,14 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester model, input_ids, attention_mask, - decoder_input_ids, max_length, output_scores=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - eos_token_id=model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) @@ -859,34 +731,17 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester **model_kwargs, ) - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_greedy = model.greedy_search( - decoder_input_ids, - max_length=max_length, - logits_processor=logits_processor, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - # Ignore copy - **model_kwargs, - ) - return output_greedy, output_generate + return output_generate # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are # different modalities -> different shapes) - # Ignore copy def _sample_generate( self, model, input_ids, attention_mask, - decoder_input_ids, max_length, num_return_sequences, - logits_processor, - logits_warper, logits_warper_kwargs, process_kwargs, output_scores=False, @@ -912,53 +767,31 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester **model_kwargs, ) - torch.manual_seed(0) - - # prevent flaky generation test failures - logits_processor.append(InfNanRemoveLogitsProcessor()) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_sample = model.sample( - decoder_input_ids.repeat_interleave(num_return_sequences, dim=0), - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **model_kwargs, - ) - - return output_sample, output_generate + return output_generate @staticmethod - def _get_logits_processor_and_kwargs( + def _get_logits_processor_and_warper_kwargs( input_length, - eos_token_id, forced_bos_token_id=None, forced_eos_token_id=None, max_length=None, - diversity_penalty=None, ): process_kwargs = { "min_length": input_length + 1 if max_length is None else max_length - 1, } - logits_processor = LogitsProcessorList() - return process_kwargs, logits_processor + warper_kwargs = {} + return process_kwargs, warper_kwargs def test_greedy_generate_dict_outputs(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, output_scores=True, output_hidden_states=True, @@ -966,7 +799,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) self.assertNotIn(config.pad_token_id, output_generate) @@ -974,16 +806,15 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.greedy_sample_model_classes: # enable cache - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, output_scores=True, output_hidden_states=True, @@ -991,64 +822,48 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) def test_sample_generate(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) # check `generate()` and `sample()` are equal - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, num_return_sequences=1, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, ) - self.assertIsInstance(output_sample, torch.Tensor) self.assertIsInstance(output_generate, torch.Tensor) def test_sample_generate_dict_output(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, num_return_sequences=3, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, output_scores=True, @@ -1057,11 +872,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester return_dict_in_generate=True, ) - self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) def test_generate_without_input_ids(self): - config, _, _, _, max_length = self._get_input_ids_and_config() + config, _, _, max_length = self._get_input_ids_and_config() # if no bos token id => cannot generate from None if config.bos_token_id is None: @@ -1090,15 +904,14 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, output_scores=True, output_hidden_states=True, @@ -1106,7 +919,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) self.assertNotIn(config.pad_token_id, output_generate)