From f292940f18013cd730ee98ae46b8f2e91ff57201 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 19 Mar 2024 22:03:31 +0500 Subject: [PATCH] Clean-up generation tests after moving methods to private (#29582) * clean-up tests * refine comments * fix musicgen tests * make style * remove slow decorator from a test * more clean-up * fix other failing tests --- tests/generation/test_utils.py | 995 +++--------------- .../models/musicgen/test_modeling_musicgen.py | 166 +-- 2 files changed, 156 insertions(+), 1005 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bd3bbe7c60..8d7d83759c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -59,33 +59,21 @@ if is_torch_available(): BeamSampleEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, - BeamSearchScorer, - ConstrainedBeamSearchScorer, DisjunctiveConstraint, - ForcedBOSTokenLogitsProcessor, - ForcedEOSTokenLogitsProcessor, GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, - HammingDiversityLogitsProcessor, - InfNanRemoveLogitsProcessor, LogitsProcessorList, MaxLengthCriteria, MinLengthLogitsProcessor, - NoBadWordsLogitsProcessor, - NoRepeatNGramLogitsProcessor, PhrasalConstraint, - RepetitionPenaltyLogitsProcessor, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput, StoppingCriteria, StoppingCriteriaList, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, ) from transformers.generation.utils import _speculative_sampling @@ -104,7 +92,10 @@ class GenerationTesterMixin: input_ids = input_ids[:batch_size, :sequence_length] # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 + if config.is_encoder_decoder: + max_length = 4 + else: + max_length = input_ids.shape[-1] + 3 if config.eos_token_id is not None and config.pad_token_id is None: # hack to 
allow generate for models such as GPT2 as is done in `generate()` if isinstance(config.eos_token_id, int): @@ -112,16 +103,19 @@ class GenerationTesterMixin: config.pad_token_id = config.eos_token_id[0] attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length] + # It is important to set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated + config.eos_token_id = None + config.forced_eos_token_id = None + return config, input_ids, attention_mask, max_length @staticmethod - def _get_logits_processor_and_kwargs( + def _get_logits_processor_and_warper_kwargs( input_length, - eos_token_id, forced_bos_token_id=None, forced_eos_token_id=None, max_length=None, - diversity_penalty=None, ): process_kwargs = { "min_length": input_length + 1 if max_length is None else max_length - 1, @@ -133,78 +127,21 @@ class GenerationTesterMixin: if forced_bos_token_id is None and forced_eos_token_id is None: process_kwargs["no_repeat_ngram_size"] = 2 - # NOTE: the order of operations here should match `generate` for accurate testing - logits_processor = LogitsProcessorList( - ( - [ - HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2), - ] - if diversity_penalty is not None - else [] - ) - + ( - [ - MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id), - ] - if eos_token_id is not None - else [] - ) - + ( - [ - ForcedBOSTokenLogitsProcessor(forced_bos_token_id), - ] - if forced_bos_token_id is not None - else [] - ) - + ( - [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)] - if forced_eos_token_id is not None - else [] - ) - + [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)] - + ( - [NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])] - if forced_bos_token_id is None and forced_eos_token_id is None - else [] - ) - + [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])] - + 
[InfNanRemoveLogitsProcessor()] # prevent flaky generation test failures - ) - - return process_kwargs, logits_processor - - @staticmethod - def _get_warper_and_kwargs(num_beams): warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7} - logits_warper = LogitsProcessorList( - [ - TemperatureLogitsWarper(warp_kwargs["temperature"]), - TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), - TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), - ] - ) - return warp_kwargs, logits_warper + return process_kwargs, warp_kwargs @staticmethod - def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): + def _get_beam_kwargs(num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, "num_beams": 2, "num_return_sequences": num_return_sequences, } - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=beam_kwargs["num_beams"], - device=torch_device, - length_penalty=beam_kwargs["length_penalty"], - do_early_stopping=beam_kwargs["early_stopping"], - num_beam_hyps_to_keep=num_return_sequences, - ) - return beam_kwargs, beam_scorer + return beam_kwargs @staticmethod - def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): + def _get_diverse_beam_kwargs(num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -213,35 +150,17 @@ class GenerationTesterMixin: "num_beam_groups": 2, # one beam per group "diversity_penalty": 2.0, } - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=beam_kwargs["num_beams"], - device=torch_device, - length_penalty=beam_kwargs["length_penalty"], - do_early_stopping=beam_kwargs["early_stopping"], - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=beam_kwargs["num_beam_groups"], - ) - return beam_kwargs, beam_scorer + return beam_kwargs @staticmethod - def 
_get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1): + def _get_constrained_beam_kwargs(num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, "num_beams": num_return_sequences * 4, "num_return_sequences": num_return_sequences, } - beam_scorer = ConstrainedBeamSearchScorer( - batch_size=batch_size, - constraints=constraints, - num_beams=beam_kwargs["num_beams"], - device=torch_device, - length_penalty=beam_kwargs["length_penalty"], - do_early_stopping=beam_kwargs["early_stopping"], - num_beam_hyps_to_keep=num_return_sequences, - ) - return beam_kwargs, beam_scorer + return beam_kwargs @staticmethod def _get_encoder_outputs( @@ -273,17 +192,13 @@ class GenerationTesterMixin: output_hidden_states=False, return_dict_in_generate=False, ): - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - eos_token_id=model.config.eos_token_id, forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -299,31 +214,7 @@ class GenerationTesterMixin: **model_kwargs, ) - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_greedy = model.greedy_search( - input_ids, - max_length=max_length, - logits_processor=logits_processor, - 
output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_scores=output_scores, - output_logits=output_logits, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_greedy, output_generate + return output_generate def _sample_generate( self, @@ -332,8 +223,6 @@ class GenerationTesterMixin: attention_mask, max_length, num_return_sequences, - logits_processor, - logits_warper, logits_warper_kwargs, process_kwargs, output_scores=False, @@ -360,38 +249,7 @@ class GenerationTesterMixin: **model_kwargs, ) - torch.manual_seed(0) - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=num_return_sequences, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_sample = model.sample( - input_ids.repeat_interleave(num_return_sequences, dim=0), - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - output_scores=output_scores, - output_logits=output_logits, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - - return output_sample, output_generate + return output_generate def _beam_search_generate( self, @@ -399,9 +257,7 @@ class GenerationTesterMixin: input_ids, attention_mask, max_length, - beam_scorer, beam_kwargs, - logits_processor, logits_process_kwargs, output_scores=False, output_logits=False, @@ -424,37 +280,7 @@ class GenerationTesterMixin: **model_kwargs, ) - # beam_search does not 
automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_beam_search = model.beam_search( - input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), - beam_scorer, - max_length=max_length, - logits_processor=logits_processor, - output_scores=output_scores, - output_logits=output_logits, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_generate, output_beam_search + return output_generate def _beam_sample_generate( self, @@ -462,9 +288,7 @@ class GenerationTesterMixin: input_ids, attention_mask, max_length, - beam_scorer, beam_kwargs, - logits_warper, logits_warper_kwargs, output_scores=False, output_logits=False, @@ -487,44 +311,8 @@ class GenerationTesterMixin: **logits_warper_kwargs, **model_kwargs, ) - # beam_search does not automatically interleave `batch_size` dim for `num_beams` - torch.manual_seed(0) - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - 
# prevent flaky generation test failures - logits_processor = LogitsProcessorList() - logits_processor.append(InfNanRemoveLogitsProcessor()) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_beam_sample = model.beam_sample( - input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), - beam_scorer, - max_length=max_length, - logits_warper=logits_warper, - logits_processor=logits_processor, - output_scores=output_scores, - output_logits=output_logits, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - - return output_generate, output_beam_sample + return output_generate def _group_beam_search_generate( self, @@ -532,9 +320,7 @@ class GenerationTesterMixin: input_ids, attention_mask, max_length, - beam_scorer, beam_kwargs, - logits_processor, logits_process_kwargs, output_scores=False, output_logits=False, @@ -557,37 +343,7 @@ class GenerationTesterMixin: **model_kwargs, ) - # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_group_beam_search = model.group_beam_search( - input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), - beam_scorer, - max_length=max_length, - logits_processor=logits_processor, - output_scores=output_scores, - output_logits=output_logits, - 
output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_generate, output_group_beam_search + return output_generate def _constrained_beam_search_generate( self, @@ -595,10 +351,8 @@ class GenerationTesterMixin: input_ids, attention_mask, max_length, - constrained_beam_scorer, constraints, beam_kwargs, - logits_processor, logits_process_kwargs, output_scores=False, output_logits=False, @@ -622,37 +376,7 @@ class GenerationTesterMixin: **model_kwargs, ) - # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=constrained_beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_group_beam_search = model.constrained_beam_search( - input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0), - constrained_beam_scorer, - max_length=max_length, - logits_processor=logits_processor, - output_scores=output_scores, - output_logits=output_logits, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_generate, output_group_beam_search + return output_generate def _contrastive_generate( self, @@ -671,17 +395,13 @@ class GenerationTesterMixin: "top_k": 5, } - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, logits_processor = 
self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - eos_token_id=model.config.eos_token_id, forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -698,52 +418,26 @@ class GenerationTesterMixin: **contrastive_search_kwargs, ) - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)]) - output_contrastive = model.contrastive_search( - input_ids, - stopping_criteria=stopping_criteria, - logits_processor=logits_processor, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_scores=output_scores, - output_logits=output_logits, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - **contrastive_search_kwargs, - ) - return output_contrastive, output_generate + return output_generate def test_greedy_generate(self): - # check `generate()` and `greedy_search()` are equal for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - # test old generation output for backwards compatibility + model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, 
max_length=max_length ) - self.assertListEqual(output_greedy.tolist(), output_generate.tolist()) + + self.assertTrue(output_generate.shape[-1] == max_length) def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: - # disable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -756,26 +450,19 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) # Retrocompatibility check - self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput) self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) else: - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check - self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput) self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) - - for output in (output_greedy, output_generate): - self._check_outputs(output, input_ids, model.config) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs(output_generate, input_ids, model.config) def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: - # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): @@ -784,7 +471,7 @@ class GenerationTesterMixin: config.use_cache = True config.is_decoder = True model = 
model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -796,82 +483,58 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) - self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) - - for output in (output_greedy, output_generate): - self._check_outputs(output, input_ids, model.config, use_cache=True) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs(output_generate, input_ids, model.config, use_cache=True) def test_sample_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - model = model_class(config).to(torch_device).eval() + model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: max_length = 4 - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) - # check `generate()` and `sample()` are equal - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, num_return_sequences=1, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, ) - self.assertListEqual(output_sample.tolist(), output_generate.tolist()) - # check `generate()` and `sample()` yield equal results for `num_return_sequences` - output_sample, 
output_generate = self._sample_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_return_sequences=3, - logits_processor=logits_processor, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, - ) - self.assertListEqual(output_sample.tolist(), output_generate.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - # disable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: max_length = 4 - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, forced_bos_token_id=model.config.forced_bos_token_id, forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, num_return_sequences=2, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, output_scores=True, @@ -882,75 +545,43 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) # Retrocompatibility check - self.assertIsInstance(output_sample, SampleEncoderDecoderOutput) self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) else: - 
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check - self.assertIsInstance(output_sample, SampleDecoderOnlyOutput) self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_sample.sequences.tolist()) - - for output in (output_sample, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=2) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2) def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - config.eos_token_id, config.forced_bos_token_id, config.forced_eos_token_id, max_length, ) - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() - # check `generate()` and `beam_search()` are equal - output_generate, output_beam_search = self._beam_search_generate( + output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, 
logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) - - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - - output_generate, output_beam_search = self._beam_search_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, - beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, - ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: @@ -959,33 +590,24 @@ class GenerationTesterMixin: # disable cache config.use_cache = False - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - config.eos_token_id, config.forced_bos_token_id, config.forced_eos_token_id, max_length, ) - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_generate, output_beam_search = self._beam_search_generate( + beam_kwargs = self._get_beam_kwargs() + output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - 
beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, output_scores=True, output_logits=True, output_hidden_states=True, @@ -993,39 +615,24 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check - self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check - self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist()) - self.assertTrue( - torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - - for output in (output_beam_search, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) def test_beam_search_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - # It is important set set the eos_token_id to None to ensure that no sequences - # 
shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - if not hasattr(config, "use_cache"): self.skipTest("This model doesn't support caching") @@ -1033,28 +640,25 @@ class GenerationTesterMixin: if model.config.is_encoder_decoder: max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - config.eos_token_id, config.forced_bos_token_id, config.forced_eos_token_id, max_length, ) - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_beam, output_generate = self._beam_search_generate( + output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, output_scores=True, output_logits=True, output_hidden_states=True, @@ -1062,12 +666,10 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) - self.assertListEqual(output_generate.sequences.tolist(), output_beam.sequences.tolist()) - - for output in (output_beam, output_generate): - self._check_outputs( - output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams - ) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs( + output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"] + ) @require_accelerate @require_torch_multi_accelerator @@ -1097,32 +699,24 @@ class GenerationTesterMixin: for 
model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) model = model_class(config).to(torch_device).eval() - # check `generate()` and `beam_search()` are equal if model.config.is_encoder_decoder: max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() - output_generate, output_beam_sample = self._beam_sample_generate( + output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist()) + + self.assertTrue(output_generate.shape[-1] == max_length) def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: @@ -1131,27 +725,19 @@ class GenerationTesterMixin: # disable cache config.use_cache = False - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - model = model_class(config).to(torch_device).eval() - logits_warper_kwargs, logits_warper = 
self._get_warper_and_kwargs(num_beams=1) + _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) if model.config.is_encoder_decoder: max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() - output_beam_sample, output_generate = self._beam_sample_generate( + output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, output_scores=True, output_logits=True, @@ -1161,27 +747,18 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check - self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) else: - self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check - self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) - self.assertTrue( - torch.allclose(output_generate["sequences_scores"], output_beam_sample["sequences_scores"], atol=1e-3) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 
0).all().item()) - - for output in (output_beam_sample, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) def test_generate_without_input_ids(self): config, _, _, max_length = self._get_input_ids_and_config() @@ -1190,6 +767,10 @@ class GenerationTesterMixin: if config.bos_token_id is None: return + # hack in case they are equal, otherwise the attn mask will be [0] + if config.bos_token_id == config.pad_token_id: + config.pad_token_id = None + for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) model.eval() @@ -1201,94 +782,65 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - config.eos_token_id, config.forced_bos_token_id, config.forced_eos_token_id, max_length, - diversity_penalty=2.0, ) # check `generate()` and `group_beam_search()` are equal - beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_generate, output_group_beam_search = self._group_beam_search_generate( + beam_kwargs = self._get_diverse_beam_kwargs() + output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - beam_scorer=beam_scorer, 
beam_kwargs=beam_kwargs, - logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) - # check `generate()` and `group_beam_search()` are equal for `num_return_sequences` + # check `group_beam_search` for higher than 1 `num_return_sequences` num_return_sequences = 2 - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, num_return_sequences=num_return_sequences - ) - output_generate, output_group_beam_search = self._group_beam_search_generate( + beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences) + output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = 
self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - config.eos_token_id, config.forced_bos_token_id, config.forced_eos_token_id, max_length, - diversity_penalty=2.0, ) - num_return_sequences = 1 - beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, num_return_sequences=num_return_sequences - ) - output_generate, output_group_beam_search = self._group_beam_search_generate( + beam_kwargs = self._get_diverse_beam_kwargs() + output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, @@ -1297,31 +849,18 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check - self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check - self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist()) - self.assertTrue( - torch.allclose( - output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3 - ) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs( + output_generate, input_ids, model.config, 
num_return_sequences=beam_kwargs["num_beams"] ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - - for output in (output_group_beam_search, output_generate): - self._check_outputs( - output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams - ) # TODO: @gante @is_flaky() @@ -1329,24 +868,16 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - model = model_class(config).to(torch_device).eval() max_length = 20 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - config.eos_token_id, config.forced_bos_token_id, config.forced_eos_token_id, max_length, ) - # check `generate()` and `constrained_beam_search()` are equal # Sample constraints min_id = 3 max_id = config.vocab_size @@ -1356,50 +887,40 @@ class GenerationTesterMixin: PhrasalConstraint(force_tokens), ] - beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, constraints, num_return_sequences=1 - ) - output_generate, output_beam_search = self._constrained_beam_search_generate( + beam_kwargs = self._get_constrained_beam_kwargs() + output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - constrained_beam_scorer=beam_scorer, constraints=constraints, 
beam_kwargs=beam_kwargs, - logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) - # check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences` + # check`constrained_beam_search` for higher than 1 `num_return_sequences` # Sample constraints force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] - num_return_sequences = 2 max_length = 20 + beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2) - beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences - ) - - output_generate, output_beam_search = self._constrained_beam_search_generate( + output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - constrained_beam_scorer=beam_scorer, constraints=constraints, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -1411,19 +932,12 @@ class GenerationTesterMixin: # disable cache config.use_cache = False - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - 
config.forced_eos_token_id = None - model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: max_length = 20 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - config.eos_token_id, config.forced_bos_token_id, config.forced_eos_token_id, max_length, @@ -1437,18 +951,14 @@ class GenerationTesterMixin: PhrasalConstraint(force_tokens), ] - beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, constraints, num_return_sequences=1 - ) - output_generate, output_beam_search = self._constrained_beam_search_generate( + beam_kwargs = self._get_constrained_beam_kwargs() + output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, - constrained_beam_scorer=beam_scorer, constraints=constraints, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, @@ -1458,30 +968,20 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check - self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check - self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist()) - self.assertTrue( - 
torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - - for output in (output_beam_search, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) def test_contrastive_generate(self): - # check `generate()` and `contrastive_search()` are equal for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): @@ -1497,10 +997,10 @@ class GenerationTesterMixin: # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() - output_contrastive, output_generate = self._contrastive_generate( + output_generate = self._contrastive_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length ) - self.assertListEqual(output_contrastive.tolist(), output_generate.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) def test_contrastive_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: @@ -1508,7 +1008,6 @@ class GenerationTesterMixin: if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): self.skipTest("Won't fix: old model with different cache format") - # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. 
@@ -1518,7 +1017,7 @@ class GenerationTesterMixin: config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_contrastive, output_generate = self._contrastive_generate( + output_generate = self._contrastive_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1530,10 +1029,8 @@ class GenerationTesterMixin: return_dict_in_generate=True, ) - self.assertListEqual(output_generate.sequences.tolist(), output_contrastive.sequences.tolist()) - - for output in (output_contrastive, output_generate): - self._check_outputs(output, input_ids, model.config, use_cache=True) + self.assertTrue(output_generate.sequences.shape[-1] == max_length) + self._check_outputs(output_generate, input_ids, model.config, use_cache=True) def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output @@ -1591,7 +1088,7 @@ class GenerationTesterMixin: ] ): self.skipTest("May fix in the future: need model-specific fixes") - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2) + config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2) # batch_size=1 is ok, but batch_size>1 will cause non-identical output config.use_cache = True @@ -2455,220 +1952,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ], ) - def test_max_length_backward_compat_greedy(self): - # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - max_length = 20 - input_ids = input_ids.expand(2, -1) - model_kwargs = 
bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - - with self.assertWarns(UserWarning): - bart_model.greedy_search( - input_ids, - max_length=max_length, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) - - def test_max_length_backward_compat_sample(self): - # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - max_length = 20 - input_ids = input_ids.expand(2, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - with torch.no_grad(): - with self.assertWarns(UserWarning): - bart_model.sample( - input_ids, - max_length=max_length, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) - - def test_max_length_backward_compat_beam_search(self): - # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = 
BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - batch_size = 1 - max_length = 20 - num_beams = 2 - - input_ids = input_ids.expand(2, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - ) - with self.assertWarns(UserWarning): - _ = bart_model.beam_search( - input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs - ) - - def test_max_length_backward_compat_group_beam_search(self): - # PT-only test: TF doesn't have StoppingCriteria & group beam search - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - batch_size = 1 - max_length = 20 - num_beams = 6 - num_beam_groups = 3 - num_return_sequences = num_beams * batch_size - - input_ids = input_ids.expand(6, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - 
model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - - diverse_beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, - ) - with self.assertWarns(UserWarning): - bart_model.group_beam_search( - input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs - ) - - def test_max_length_warning_if_different(self): - # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - batch_size = 1 - - max_length = 20 - num_beams = 6 - num_beam_groups = 3 - num_return_sequences = num_beams * batch_size - stopping_criteria_max_length = 18 - stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)]) - - # Greedy - input_ids = input_ids.expand(6, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - - with self.assertWarns(UserWarning): - bart_model.greedy_search( - input_ids, - max_length=max_length, - pad_token_id=bart_model.config.pad_token_id, - stopping_criteria=stopping_criteria, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) - - # Sample - with 
self.assertWarns(UserWarning): - with torch.no_grad(): - bart_model.sample( - input_ids, - max_length=max_length, - stopping_criteria=stopping_criteria, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) - - # Beam - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - ) - with self.assertWarns(UserWarning): - with torch.no_grad(): - bart_model.beam_search( - input_ids, - num_beams=num_beams, - stopping_criteria=stopping_criteria, - max_length=max_length, - beam_scorer=beam_scorer, - **model_kwargs, - ) - - # Grouped beam search - diverse_beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, - ) - with self.assertWarns(UserWarning): - bart_model.group_beam_search( - input_ids, - diverse_beam_scorer, - stopping_criteria=stopping_criteria, - num_beams=num_beams, - max_length=max_length, - **model_kwargs, - ) - def test_max_length_if_input_embeds(self): # PT-only test: TF doesn't have StoppingCriteria article = "Today a dragon flew over Paris." 
@@ -2819,31 +2102,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi # lets run beam search using 3 beams num_beams = 3 # define decoder start token ids - input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) input_ids = input_ids * model.config.decoder_start_token_id # add encoder_outputs to model keyword arguments - model_kwargs = { - "encoder_outputs": model.get_encoder()( - encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ) - } + model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)} - # instantiate beam scorer - beam_scorer = BeamSearchScorer( - batch_size=1, - num_beams=num_beams, - device=model.device, + outputs = model.generate( + input_ids, num_beams=num_beams, min_length=5, eos_token_id=model.config.eos_token_id, **model_kwargs ) - - # instantiate logits processors - logits_processor = LogitsProcessorList( - [ - MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ] - ) - - outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertListEqual(outputs, ["Wie alt bist du?"]) @@ -3042,34 +2309,22 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi # lets run beam search using 5 beams num_beams = 5 # define decoder start token ids - input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) input_ids = input_ids * model.config.decoder_start_token_id # add encoder_outputs to model keyword arguments - model_kwargs = { - "encoder_outputs": model.get_encoder()( - encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ) - } + model_kwargs = {"encoder_outputs": 
model.get_encoder()(encoder_input_ids, return_dict=True)} constraint_str = "sind" constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token - constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] - # instantiate beam scorer - beam_scorer = ConstrainedBeamSearchScorer( - batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints - ) - - # instantiate logits processors - logits_processor = LogitsProcessorList( - [ - MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ] - ) - - outputs = model.constrained_beam_search( - input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs + outputs = model.generate( + input_ids, + num_beams=num_beams, + force_words_ids=[constraint_token_ids], + min_length=5, + eos_token_id=model.config.eos_token_id, + **model_kwargs, ) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index e2e7da36ea..f7ceb0a8bf 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -55,8 +55,6 @@ if is_torch_available(): from transformers.generation import ( GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, - InfNanRemoveLogitsProcessor, - LogitsProcessorList, ) @@ -247,19 +245,17 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste return config, input_ids, attention_mask, max_length @staticmethod - def _get_logits_processor_and_kwargs( + def _get_logits_processor_and_warper_kwargs( input_length, - eos_token_id, forced_bos_token_id=None, forced_eos_token_id=None, max_length=None, - diversity_penalty=None, ): process_kwargs = { "min_length": input_length + 1 if max_length is None else max_length - 1, } - logits_processor = LogitsProcessorList() - return process_kwargs, logits_processor + warper_kwargs = {} + return 
process_kwargs, warper_kwargs # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform # additional post-processing in the former @@ -269,7 +265,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), @@ -280,9 +276,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) - self.assertNotIn(config.pad_token_id, output_generate) # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform @@ -295,7 +289,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), @@ -306,7 +300,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform @@ -316,28 +309,21 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste config, input_ids, attention_mask, 
max_length = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) # check `generate()` and `sample()` are equal - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), max_length=max_length, num_return_sequences=3, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, ) - self.assertIsInstance(output_sample, torch.Tensor) self.assertIsInstance(output_generate, torch.Tensor) # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform @@ -349,23 +335,17 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste config.use_cache = False model = model_class(config).to(torch_device).eval() - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), 
max_length=max_length, num_return_sequences=1, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, output_scores=True, @@ -374,7 +354,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste return_dict_in_generate=True, ) - self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) def test_greedy_generate_stereo_outputs(self): @@ -382,7 +361,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), @@ -393,7 +372,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) self.assertNotIn(config.pad_token_id, output_generate) @@ -834,10 +812,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) # generate max 3 tokens - decoder_input_ids = inputs_dict["decoder_input_ids"] - max_length = decoder_input_ids.shape[-1] + 3 - decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :] - return config, input_ids, attention_mask, decoder_input_ids, max_length + max_length = 3 + return config, input_ids, attention_mask, max_length # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are # different modalities -> different shapes) @@ 
-846,18 +822,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, model, input_ids, attention_mask, - decoder_input_ids, max_length, output_scores=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - eos_token_id=model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) @@ -876,28 +848,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, **model_kwargs, ) - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_greedy = model.greedy_search( - decoder_input_ids, - max_length=max_length, - logits_processor=logits_processor, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - encoder_outputs=encoder_outputs, - **model_kwargs, - ) - return output_greedy, output_generate + return output_generate # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are # different modalities -> different shapes) @@ -906,11 +857,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, model, input_ids, attention_mask, - decoder_input_ids, max_length, num_return_sequences, - logits_processor, - logits_warper, logits_warper_kwargs, process_kwargs, output_scores=False, @@ -936,62 +884,31 @@ class MusicgenTest(ModelTesterMixin, 
GenerationTesterMixin, PipelineTesterMixin, **model_kwargs, ) - torch.manual_seed(0) - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=num_return_sequences, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - # prevent flaky generation test failures - logits_processor.append(InfNanRemoveLogitsProcessor()) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_sample = model.sample( - decoder_input_ids.repeat_interleave(num_return_sequences, dim=0), - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - encoder_outputs=encoder_outputs, - **model_kwargs, - ) - - return output_sample, output_generate + return output_generate @staticmethod - def _get_logits_processor_and_kwargs( + def _get_logits_processor_and_warper_kwargs( input_length, - eos_token_id, forced_bos_token_id=None, forced_eos_token_id=None, max_length=None, - diversity_penalty=None, ): process_kwargs = { "min_length": input_length + 1 if max_length is None else max_length - 1, } - logits_processor = LogitsProcessorList() - return process_kwargs, logits_processor + warper_kwargs = {} + return process_kwargs, warper_kwargs def test_greedy_generate_dict_outputs(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, 
input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, output_scores=True, output_hidden_states=True, @@ -999,7 +916,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) self.assertNotIn(config.pad_token_id, output_generate) @@ -1007,16 +923,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.greedy_sample_model_classes: # enable cache - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, output_scores=True, output_hidden_states=True, @@ -1024,64 +939,48 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) def test_sample_generate(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - process_kwargs, logits_processor = 
self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) # check `generate()` and `sample()` are equal - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, num_return_sequences=1, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, ) - self.assertIsInstance(output_sample, torch.Tensor) self.assertIsInstance(output_generate, torch.Tensor) def test_sample_generate_dict_output(self): for model_class in self.greedy_sample_model_classes: # disable cache - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - 
decoder_input_ids=decoder_input_ids, max_length=max_length, num_return_sequences=3, - logits_processor=logits_processor, - logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, output_scores=True, @@ -1090,11 +989,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, return_dict_in_generate=True, ) - self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) def test_generate_without_input_ids(self): - config, _, _, _, max_length = self._get_input_ids_and_config() + config, _, _, max_length = self._get_input_ids_and_config() # if no bos token id => cannot generate from None if config.bos_token_id is None: @@ -1123,15 +1021,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: - config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.audio_channels = 2 model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids.to(torch_device), attention_mask=attention_mask.to(torch_device), - decoder_input_ids=decoder_input_ids, max_length=max_length, output_scores=True, output_hidden_states=True, @@ -1139,7 +1036,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, return_dict_in_generate=True, ) - self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) self.assertNotIn(config.pad_token_id, output_generate)