Fix flaky test_beam_search_low_memory (#35611)
* fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -204,6 +204,8 @@ class GenerationTesterMixin:
|
||||
"vision_start_token_id",
|
||||
]:
|
||||
token_index = getattr(config, key, None)
|
||||
if token_index is None and hasattr(self, "model_tester"):
|
||||
token_index = getattr(self.model_tester, key, None)
|
||||
if token_index is not None and token_index < config.get_text_config().vocab_size:
|
||||
logits_processor_kwargs["bad_words_ids"].append([token_index])
|
||||
|
||||
@@ -1077,7 +1079,10 @@ class GenerationTesterMixin:
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
set_model_tester_for_less_flaky_test(self)
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
set_config_for_less_flaky_test(config)
|
||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||
|
||||
config.use_cache = True
|
||||
@@ -1085,6 +1090,9 @@ class GenerationTesterMixin:
|
||||
|
||||
# test output equality of low versus high memory
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
set_model_for_less_flaky_test(model)
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
||||
|
||||
low_output = model.generate(
|
||||
**inputs_dict,
|
||||
@@ -1093,6 +1101,10 @@ class GenerationTesterMixin:
|
||||
early_stopping=True,
|
||||
low_memory=True,
|
||||
use_cache=True,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
return_dict_in_generate=True,
|
||||
**logits_processor_kwargs,
|
||||
)
|
||||
|
||||
high_output = model.generate(
|
||||
@@ -1102,8 +1114,13 @@ class GenerationTesterMixin:
|
||||
early_stopping=True,
|
||||
low_memory=False,
|
||||
use_cache=True,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
return_dict_in_generate=True,
|
||||
**logits_processor_kwargs,
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(low_output, high_output)
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
|
||||
Reference in New Issue
Block a user