From b1cd48740ea52535926631e9e42beee4ba8d8740 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 19 Apr 2024 21:32:52 +0500 Subject: [PATCH] Do not remove half seq length in generation tests (#30016) * remove seq length from generation tests * style and quality * [test_all] & PR suggestion Co-authored-by: Joao Gante * Update tests/generation/test_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * [test all] remove unused variables --------- Co-authored-by: Joao Gante Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- tests/generation/test_utils.py | 249 ++++++++---------- .../test_modeling_bigbird_pegasus.py | 4 +- tests/models/led/test_modeling_led.py | 14 + tests/models/longt5/test_modeling_longt5.py | 4 +- .../models/musicgen/test_modeling_musicgen.py | 72 ++--- .../test_modeling_musicgen_melody.py | 72 ++--- .../models/reformer/test_modeling_reformer.py | 12 + .../test_modeling_speech_to_text.py | 4 +- tests/models/whisper/test_modeling_whisper.py | 4 +- tests/models/xlnet/test_modeling_xlnet.py | 6 +- 10 files changed, 180 insertions(+), 261 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 8382273bef..a8edd33273 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -82,43 +82,35 @@ class GenerationTesterMixin: model_tester = None all_generative_model_classes = () input_name = "input_ids" + max_new_tokens = 3 def _get_input_ids_and_config(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict[self.input_name] - # cut to half length & take max batch_size 3 - sequence_length = input_ids.shape[-1] // 2 - input_ids = input_ids[:batch_size, :sequence_length] + input_ids = input_ids[:batch_size] - # generate max 3 tokens - if config.is_encoder_decoder: - max_length = 4 - else: - max_length = input_ids.shape[-1] + 3 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` if isinstance(config.eos_token_id, int): config.eos_token_id = [config.eos_token_id] config.pad_token_id = config.eos_token_id[0] - attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length] + attention_mask = torch.ones_like(input_ids, dtype=torch.long) # It is important set set the eos_token_id to None to ensure that no sequences # shorter than `max_length` can be generated config.eos_token_id = None config.forced_eos_token_id = None - return config, input_ids, attention_mask, max_length + return config, input_ids, attention_mask @staticmethod def _get_logits_processor_and_warper_kwargs( input_length, forced_bos_token_id=None, forced_eos_token_id=None, - max_length=None, ): process_kwargs = { - "min_length": input_length + 1 if max_length is None else max_length - 1, "bad_words_ids": [[1, 0]], "repetition_penalty": 1.2, "remove_invalid_values": True, @@ -185,7 +177,6 @@ class GenerationTesterMixin: model, input_ids, attention_mask, - max_length, output_scores=False, output_logits=False, output_attentions=False, @@ -196,7 +187,6 @@ class GenerationTesterMixin: input_ids.shape[-1], forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, ) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -204,7 +194,7 @@ class GenerationTesterMixin: input_ids, do_sample=False, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, @@ -221,7 +211,6 @@ class GenerationTesterMixin: model, input_ids, attention_mask, - max_length, num_return_sequences, logits_warper_kwargs, process_kwargs, @@ -237,7 +226,7 @@ class GenerationTesterMixin: input_ids, do_sample=True, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, num_return_sequences=num_return_sequences, output_scores=output_scores, output_logits=output_logits, @@ -256,7 +245,6 @@ class GenerationTesterMixin: model, input_ids, attention_mask, - max_length, beam_kwargs, logits_process_kwargs, output_scores=False, @@ -269,7 +257,7 @@ class GenerationTesterMixin: output_generate = model.generate( input_ids, do_sample=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, output_logits=output_logits, output_attentions=output_attentions, @@ -287,7 +275,6 @@ class GenerationTesterMixin: model, input_ids, attention_mask, - max_length, beam_kwargs, logits_warper_kwargs, output_scores=False, @@ -301,7 +288,7 @@ class GenerationTesterMixin: output_generate = model.generate( input_ids, do_sample=True, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, output_logits=output_logits, output_attentions=output_attentions, @@ -319,7 +306,6 @@ class GenerationTesterMixin: model, input_ids, attention_mask, - max_length, beam_kwargs, logits_process_kwargs, output_scores=False, @@ -332,7 +318,7 @@ class GenerationTesterMixin: output_generate = model.generate( input_ids, do_sample=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, output_logits=output_logits, output_attentions=output_attentions, @@ -350,7 +336,6 @@ class GenerationTesterMixin: model, input_ids, attention_mask, - max_length, constraints, beam_kwargs, logits_process_kwargs, @@ -364,7 +349,7 @@ class GenerationTesterMixin: output_generate = model.generate( input_ids, do_sample=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, output_logits=output_logits, output_attentions=output_attentions, @@ -383,7 +368,6 @@ class GenerationTesterMixin: model, input_ids, attention_mask, - max_length, output_scores=False, output_logits=False, output_attentions=False, @@ -399,7 +383,6 @@ class GenerationTesterMixin: input_ids.shape[-1], forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, ) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -407,7 +390,7 @@ class GenerationTesterMixin: input_ids, do_sample=False, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, @@ -422,18 +405,19 @@ class GenerationTesterMixin: def test_greedy_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - output_generate = self._greedy_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length - ) + output_generate = self._greedy_generate(model=model, input_ids=input_ids, attention_mask=attention_mask) - self.assertTrue(output_generate.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() @@ -441,7 +425,6 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, output_scores=True, output_logits=True, output_hidden_states=True, @@ -450,20 +433,21 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) self._check_outputs(output_generate, input_ids, model.config) def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): self.skipTest("This model doesn't support caching") @@ -475,7 +459,6 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, output_scores=True, output_logits=True, output_hidden_states=True, @@ -483,57 +466,54 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self._check_outputs(output_generate, input_ids, model.config, use_cache=True) def test_sample_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, ) output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, num_return_sequences=1, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, ) output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, num_return_sequences=2, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, @@ -545,30 +525,28 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2) def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], config.forced_bos_token_id, config.forced_eos_token_id, - max_length, ) beam_kwargs = self._get_beam_kwargs() @@ -576,36 +554,33 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() # disable cache config.use_cache = False model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], config.forced_bos_token_id, config.forced_eos_token_id, - max_length, ) beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, output_scores=True, @@ -615,15 +590,16 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) self._check_outputs( output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) @@ -631,20 +607,16 @@ class GenerationTesterMixin: def test_beam_search_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): self.skipTest("This model doesn't support caching") model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], config.forced_bos_token_id, config.forced_eos_token_id, - max_length, ) beam_kwargs = self._get_beam_kwargs() @@ -656,7 +628,6 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, output_scores=True, @@ -666,7 +637,10 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self._check_outputs( output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"] ) @@ -681,7 +655,7 @@ class GenerationTesterMixin: if model_class._no_split_modules is None: continue - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).eval() with tempfile.TemporaryDirectory() as tmp_dir: @@ -691,32 +665,32 @@ class GenerationTesterMixin: new_model.generate( input_ids, attention_mask=attention_mask, - max_length=max_length, + max_new_tokens=self.max_new_tokens, num_beams=2, ) def test_beam_sample_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) model = model_class(config).to(torch_device).eval() - - if model.config.is_encoder_decoder: - max_length = 4 beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_warper_kwargs=logits_warper_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): input_embeds = model.get_input_embeddings()(input_ids) beam_kwargs.update({"inputs_embeds": input_embeds}) @@ -724,7 +698,6 @@ class GenerationTesterMixin: model=model, input_ids=None, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_warper_kwargs=logits_warper_kwargs, ) @@ -733,23 +706,19 @@ class GenerationTesterMixin: def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() # disable cache config.use_cache = False model = model_class(config).to(torch_device).eval() _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) - - if model.config.is_encoder_decoder: - max_length = 4 beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_warper_kwargs=logits_warper_kwargs, output_scores=True, @@ -760,21 +729,22 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) self._check_outputs( output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) def test_generate_without_input_ids(self): - config, _, _, max_length = self._get_input_ids_and_config() + config, _, _ = self._get_input_ids_and_config() # if no bos token id => cannot generate from None if config.bos_token_id is None: @@ -788,22 +758,20 @@ class GenerationTesterMixin: model = model_class(config).to(torch_device) model.eval() - output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) + output_ids_generate = model.generate( + do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True + ) self.assertIsNotNone(output_ids_generate) def test_group_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], config.forced_bos_token_id, config.forced_eos_token_id, - max_length, ) # check `generate()` and `group_beam_search()` are equal @@ -812,11 +780,13 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) # check `group_beam_search` for higher than 1 `num_return_sequences` num_return_sequences = 2 @@ -825,26 +795,24 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], config.forced_bos_token_id, config.forced_eos_token_id, - max_length, ) beam_kwargs = self._get_diverse_beam_kwargs() @@ -852,7 +820,6 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, output_scores=True, @@ -862,15 +829,16 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) self._check_outputs( output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) @@ -879,16 +847,14 @@ class GenerationTesterMixin: @is_flaky() def test_constrained_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - max_length = 20 logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], config.forced_bos_token_id, config.forced_eos_token_id, - max_length, ) # Sample constraints @@ -905,12 +871,16 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, constraints=constraints, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -921,39 +891,37 @@ class GenerationTesterMixin: PhrasalConstraint(force_tokens), ] - max_length = 20 beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2) output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, constraints=constraints, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) def test_constrained_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() # disable cache config.use_cache = False model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 20 - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], config.forced_bos_token_id, config.forced_eos_token_id, - max_length, ) # Sample constraints @@ -969,7 +937,6 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, constraints=constraints, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, @@ -981,15 +948,16 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) self._check_outputs( output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) @@ -1000,7 +968,7 @@ class GenerationTesterMixin: if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): self.skipTest("Won't fix: old model with different cache format") - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1011,9 +979,12 @@ class GenerationTesterMixin: # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() output_generate = self._contrastive_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length + model=model, input_ids=input_ids, attention_mask=attention_mask ) - self.assertTrue(output_generate.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_contrastive_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: @@ -1021,7 +992,7 @@ class GenerationTesterMixin: if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): self.skipTest("Won't fix: old model with different cache format") - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1034,7 +1005,6 @@ class GenerationTesterMixin: model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, output_scores=True, output_logits=True, output_hidden_states=True, @@ -1042,7 +1012,10 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self._check_outputs(output_generate, input_ids, model.config, use_cache=True) def test_contrastive_generate_low_memory(self): @@ -1053,7 +1026,7 @@ class GenerationTesterMixin: if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]): self.skipTest("TODO: fix me") - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1070,7 +1043,7 @@ class GenerationTesterMixin: top_k=4, penalty_alpha=0.6, low_memory=True, - max_length=max_length, + max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, ) @@ -1079,7 +1052,7 @@ class GenerationTesterMixin: top_k=4, penalty_alpha=0.6, low_memory=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1102,7 +1075,7 @@ class GenerationTesterMixin: ] ): self.skipTest("May fix in the future: need model-specific fixes") - config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2) + config, input_ids, _ = self._get_input_ids_and_config(batch_size=2) # batch_size=1 is ok, but batch_size>1 will cause non-identical output config.use_cache = True @@ -1150,7 +1123,7 @@ class GenerationTesterMixin: self.skipTest("May fix in the future: need model-specific fixes") # enable cache - config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1213,7 +1186,7 @@ class GenerationTesterMixin: self.skipTest("May fix in the future: need model-specific fixes") # enable cache - config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1273,7 +1246,7 @@ class GenerationTesterMixin: self.skipTest("May fix in the future: need model-specific fixes") # enable cache - config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1311,7 +1284,7 @@ class GenerationTesterMixin: """Test designed for encoder-decoder models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() # We want to test only encoder-decoder models if not config.is_encoder_decoder: continue @@ -1358,7 +1331,7 @@ class GenerationTesterMixin: # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) decoder_only_classes = [] for model_class in self.all_generative_model_classes: - config, _, _, _ = self._get_input_ids_and_config() + config, _, _ = self._get_input_ids_and_config() if config.is_encoder_decoder: continue else: @@ -1391,7 +1364,7 @@ class GenerationTesterMixin: return model_kwargs for model_class in decoder_only_classes: - config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() signature = inspect.signature(model.forward).parameters.keys() @@ -1485,7 +1458,7 @@ class GenerationTesterMixin: # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` # if fails, you should probably update the `prepare_inputs_for_generation` function for model_class in self.all_generative_model_classes: - config, input_ids, _, _ = self._get_input_ids_and_config() + config, input_ids, _ = self._get_input_ids_and_config() # Ignore: # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids, @@ -1616,7 +1589,7 @@ class GenerationTesterMixin: if not model_class._supports_cache_class: self.skipTest("This model does not support the new cache format") - config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = True config.is_decoder = True diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index 96e7ce639f..82b7cb574d 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -299,12 +299,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT input_ids = input_ids[:batch_size, :sequence_length] attention_mask = attention_mask[:batch_size, :sequence_length] - # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` config.pad_token_id = config.eos_token_id - return config, input_ids, attention_mask, max_length + return config, input_ids, attention_mask def setUp(self): self.model_tester = BigBirdPegasusModelTester(self) diff --git a/tests/models/led/test_modeling_led.py b/tests/models/led/test_modeling_led.py index 120308db90..10d944c496 100644 --- a/tests/models/led/test_modeling_led.py +++ b/tests/models/led/test_modeling_led.py @@ -457,6 +457,20 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ], ) + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + # overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape + encoder_expected_shape = ( + batch_size, + config.num_attention_heads, + seq_length, + self.model_tester.attention_window // 2 * 2 + 1, + ) + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [layer_attentions.shape for layer_attentions in attentions], + [encoder_expected_shape] * len(attentions), + ) + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index c65af001e1..42efd5f01e 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -752,7 +752,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): block_len = getattr(self.model_tester, "block_len", None) - encoder_expected_shape = (batch_size, 1, config.num_attention_heads, block_len, 3 * block_len) + encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len) self.assertIsInstance(attentions, tuple) self.assertListEqual( [layer_attentions.shape for layer_attentions in attentions], @@ -885,7 +885,7 @@ class LongT5TGlobalModelTest(LongT5ModelTest): global_seq_length = seq_length // global_block_size encoder_expected_shape = ( batch_size, - 1, + 2, config.num_attention_heads, block_len, 3 * block_len + global_seq_length, diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index df1df64c9c..dff8a6f6fe 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -245,34 +245,28 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste sequence_length = input_ids.shape[-1] input_ids = input_ids[: batch_size * config.num_codebooks, :] - # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) - return config, input_ids, attention_mask, max_length + return config, input_ids, attention_mask @staticmethod def _get_logits_processor_and_warper_kwargs( input_length, forced_bos_token_id=None, forced_eos_token_id=None, - max_length=None, ): - process_kwargs = { - "min_length": input_length + 1 if max_length is None else max_length - 1, - } + process_kwargs = {} warper_kwargs = {} return process_kwargs, warper_kwargs def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1327,9 +1321,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, input_ids = input_ids[:batch_size, :] attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) - # generate max 3 tokens - max_length = 3 - return config, input_ids, attention_mask, max_length + return config, input_ids, attention_mask # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are # different modalities -> different shapes) @@ -1338,29 +1330,22 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, model, input_ids, attention_mask, - max_length, output_scores=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, do_sample=False, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, remove_invalid_values=True, - **logits_process_kwargs, **model_kwargs, ) @@ -1373,10 +1358,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, model, input_ids, attention_mask, - max_length, num_return_sequences, - logits_warper_kwargs, - process_kwargs, output_scores=False, output_attentions=False, output_hidden_states=False, @@ -1388,15 +1370,13 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, input_ids, do_sample=True, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, num_return_sequences=num_return_sequences, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, remove_invalid_values=True, - **logits_warper_kwargs, - **process_kwargs, **model_kwargs, ) @@ -1407,25 +1387,21 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, input_length, forced_bos_token_id=None, forced_eos_token_id=None, - max_length=None, ): - process_kwargs = { - "min_length": input_length + 1 if max_length is None else max_length - 1, - } + process_kwargs = {} warper_kwargs = {} return process_kwargs, warper_kwargs def test_greedy_generate_dict_outputs(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1439,7 +1415,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.greedy_sample_model_classes: # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = True config.is_decoder = True @@ -1448,7 +1424,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1459,46 +1434,30 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_sample_generate(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - # check `generate()` and `sample()` are equal output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, num_return_sequences=1, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, ) self.assertIsInstance(output_generate, torch.Tensor) def test_sample_generate_dict_output(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, num_return_sequences=3, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1508,7 +1467,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) def test_generate_without_input_ids(self): - config, _, _, max_length = self._get_input_ids_and_config() + config, _, _ = self._get_input_ids_and_config() # if no bos token id => cannot generate from None if config.bos_token_id is None: @@ -1518,7 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, model = model_class(config).to(torch_device) model.eval() - output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) + output_ids_generate = model.generate( + do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True + ) self.assertIsNotNone(output_ids_generate) @require_torch_fp16 @@ -1537,7 +1498,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() @@ -1545,7 +1506,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 667958a251..9931bcb32a 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -246,34 +246,28 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes sequence_length = input_ids.shape[-1] input_ids = input_ids[: batch_size * config.num_codebooks, :] - # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) - return config, input_ids, attention_mask, max_length + return config, input_ids, attention_mask @staticmethod def _get_logits_processor_and_warper_kwargs( input_length, forced_bos_token_id=None, forced_eos_token_id=None, - max_length=None, ): - process_kwargs = { - "min_length": input_length + 1 if max_length is None else max_length - 1, - } + process_kwargs = {} warper_kwargs = {} return process_kwargs, warper_kwargs def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1309,9 +1303,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester input_ids = input_ids[:batch_size, :] attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) - # generate max 3 tokens - max_length = 3 - return config, input_ids, attention_mask, max_length + return config, input_ids, attention_mask # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are # different modalities -> different shapes) @@ -1320,29 +1312,22 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester model, input_ids, attention_mask, - max_length, output_scores=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, do_sample=False, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, remove_invalid_values=True, - **logits_process_kwargs, **model_kwargs, ) @@ -1355,10 +1340,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester model, input_ids, attention_mask, - max_length, num_return_sequences, - logits_warper_kwargs, - process_kwargs, output_scores=False, output_attentions=False, output_hidden_states=False, @@ -1370,15 +1352,13 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester input_ids, do_sample=True, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, num_return_sequences=num_return_sequences, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, remove_invalid_values=True, - **logits_warper_kwargs, - **process_kwargs, **model_kwargs, ) @@ -1389,25 +1369,21 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester input_length, forced_bos_token_id=None, forced_eos_token_id=None, - max_length=None, ): - process_kwargs = { - "min_length": input_length + 1 if max_length is None else max_length - 1, - } + process_kwargs = {} warper_kwargs = {} return process_kwargs, warper_kwargs def test_greedy_generate_dict_outputs(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1421,7 +1397,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.greedy_sample_model_classes: # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = True config.is_decoder = True @@ -1430,7 +1406,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1441,46 +1416,30 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_sample_generate(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - # check `generate()` and `sample()` are equal output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, num_return_sequences=1, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, ) self.assertIsInstance(output_generate, torch.Tensor) def test_sample_generate_dict_output(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - max_length=max_length, - ) - output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, num_return_sequences=3, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, output_scores=True, output_hidden_states=True, output_attentions=True, @@ -1490,7 +1449,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) def test_generate_without_input_ids(self): - config, _, _, max_length = self._get_input_ids_and_config() + config, _, _ = self._get_input_ids_and_config() # if no bos token id => cannot generate from None if config.bos_token_id is None: @@ -1500,7 +1459,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester model = model_class(config).to(torch_device) model.eval() - output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) + output_ids_generate = model.generate( + do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True + ) self.assertIsNotNone(output_ids_generate) @require_torch_fp16 @@ -1519,7 +1480,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() @@ -1527,7 +1488,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - max_length=max_length, output_scores=True, output_hidden_states=True, output_attentions=True, diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py index d3996a31c6..3a33a682d1 100644 --- a/tests/models/reformer/test_modeling_reformer.py +++ b/tests/models/reformer/test_modeling_reformer.py @@ -686,6 +686,18 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod def test_left_padding_compatibility(self): pass + def _get_input_ids_and_config(self, batch_size=2): + # override because overwise we hit max possible seq length for model (4*8=32) + # decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length + # NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict[self.input_name] + input_ids = input_ids[:batch_size, :16] + attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16] + config.eos_token_id = None + config.forced_eos_token_id = None + return config, input_ids, attention_mask + @require_torch class ReformerLSHAttnModelTest( diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 36a973d99d..f3fc72ab8e 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -285,7 +285,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest input_name = "input_features" def _get_input_ids_and_config(self, batch_size=2): - config, input_ids, attention_mask, max_length = GenerationTesterMixin._get_input_ids_and_config(self) + config, input_ids, attention_mask = GenerationTesterMixin._get_input_ids_and_config(self) # `input_ids` is actually `input_features` which is a 3D tensor. # We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an @@ -294,7 +294,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest sequence_length = input_ids.shape[1] attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device) - return config, input_ids, attention_mask, max_length + return config, input_ids, attention_mask def setUp(self): self.model_tester = Speech2TextModelTester(self) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 6acecb8a48..44b6c1ea74 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -477,13 +477,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi # cut to half length & take max batch_size=batch_size input_ids = input_ids[:batch_size, :, :] - # generate max 3 tokens - max_length = 4 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` config.pad_token_id = config.eos_token_id - return config, input_ids, None, max_length + return config, input_ids, None def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/xlnet/test_modeling_xlnet.py b/tests/models/xlnet/test_modeling_xlnet.py index ff89a9aca3..e2c0f6d7e7 100644 --- a/tests/models/xlnet/test_modeling_xlnet.py +++ b/tests/models/xlnet/test_modeling_xlnet.py @@ -646,7 +646,8 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi seq_len = 1 else: # for first item dummy PAD token is appended so need one more - seq_len = (min_length + 1) if idx == 0 else min_length + # else offset+dummy_token when using cache + seq_len = (min_length + 1) if idx == 0 else 3 expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) self.assertEqual(layer_hidden_states.shape, expected_shape) @@ -665,8 +666,11 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi tgt_len = min_length # for first item dummy PAD token is appended so need one more + # every token after consists of offset+dummy_token length when using cache if idx == 0: tgt_len += 1 + else: + tgt_len = 3 src_len = min_length + idx + 1