diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index dac8337d52..e54a3579cb 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -625,6 +625,12 @@ class GenerationTesterMixin: def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + model = model_class(config).to(torch_device).eval() logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( @@ -669,9 +675,16 @@ class GenerationTesterMixin: def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - # disable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # disable cache config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + model = model_class(config).to(torch_device).eval() logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( input_ids.shape[-1], config.eos_token_id @@ -715,11 +728,15 @@ class GenerationTesterMixin: # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + if not hasattr(config, "use_cache"): # only relevant if model has "use_cache" return - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( @@ -758,7 +775,12 @@ class GenerationTesterMixin: def test_beam_sample_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - print("Return dict", config.return_dict) + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) model = model_class(config).to(torch_device).eval() @@ -788,9 +810,16 @@ class GenerationTesterMixin: def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - # disable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # disable cache config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + model = model_class(config).to(torch_device).eval() logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) @@ -859,6 +888,11 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0 ) @@ -904,6 +938,12 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + model = model_class(config).to(torch_device).eval() logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(