From 04eae987f3f856724fddba9d6f9afbcacfdfed27 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:31:03 +0100 Subject: [PATCH] Fix flaky `test_beam_search_low_memory` (#35611) * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh --- tests/generation/test_utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ad2aeacbfe..510f3fe1a9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -204,6 +204,8 @@ class GenerationTesterMixin: "vision_start_token_id", ]: token_index = getattr(config, key, None) + if token_index is None and hasattr(self, "model_tester"): + token_index = getattr(self.model_tester, key, None) if token_index is not None and token_index < config.get_text_config().vocab_size: logits_processor_kwargs["bad_words_ids"].append([token_index]) @@ -1077,7 +1079,10 @@ class GenerationTesterMixin: ): self.skipTest(reason="May fix in the future: need model-specific fixes") + set_model_tester_for_less_flaky_test(self) + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + set_config_for_less_flaky_test(config) # batch_size=1 is ok, but batch_size>1 will cause non-identical output config.use_cache = True @@ -1085,6 +1090,9 @@ class GenerationTesterMixin: # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() + set_model_for_less_flaky_test(model) + + logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config) low_output = model.generate( **inputs_dict, @@ -1093,6 +1101,10 @@ class GenerationTesterMixin: early_stopping=True, low_memory=True, use_cache=True, + output_scores=True, + output_logits=True, + return_dict_in_generate=True, + **logits_processor_kwargs, ) high_output = model.generate( @@ -1102,8 +1114,13 @@ class GenerationTesterMixin: early_stopping=True, low_memory=False, use_cache=True, + output_scores=True, + output_logits=True, + return_dict_in_generate=True, + **logits_processor_kwargs, ) - self.assertListEqual(low_output.tolist(), high_output.tolist()) + # The two outputs must match and their shape must be as expected + self._check_similar_generate_outputs(low_output, high_output) @pytest.mark.generate @parameterized.expand([("random",), ("same",)])