Fix flaky test_beam_search_low_memory (#35611)
* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -204,6 +204,8 @@ class GenerationTesterMixin:
|
|||||||
"vision_start_token_id",
|
"vision_start_token_id",
|
||||||
]:
|
]:
|
||||||
token_index = getattr(config, key, None)
|
token_index = getattr(config, key, None)
|
||||||
|
if token_index is None and hasattr(self, "model_tester"):
|
||||||
|
token_index = getattr(self.model_tester, key, None)
|
||||||
if token_index is not None and token_index < config.get_text_config().vocab_size:
|
if token_index is not None and token_index < config.get_text_config().vocab_size:
|
||||||
logits_processor_kwargs["bad_words_ids"].append([token_index])
|
logits_processor_kwargs["bad_words_ids"].append([token_index])
|
||||||
|
|
||||||
@@ -1077,7 +1079,10 @@ class GenerationTesterMixin:
|
|||||||
):
|
):
|
||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
|
set_model_tester_for_less_flaky_test(self)
|
||||||
|
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
set_config_for_less_flaky_test(config)
|
||||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
@@ -1085,6 +1090,9 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# test output equality of low versus high memory
|
# test output equality of low versus high memory
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
set_model_for_less_flaky_test(model)
|
||||||
|
|
||||||
|
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
||||||
|
|
||||||
low_output = model.generate(
|
low_output = model.generate(
|
||||||
**inputs_dict,
|
**inputs_dict,
|
||||||
@@ -1093,6 +1101,10 @@ class GenerationTesterMixin:
|
|||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
low_memory=True,
|
low_memory=True,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
|
output_scores=True,
|
||||||
|
output_logits=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
**logits_processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
high_output = model.generate(
|
high_output = model.generate(
|
||||||
@@ -1102,8 +1114,13 @@ class GenerationTesterMixin:
|
|||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
low_memory=False,
|
low_memory=False,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
|
output_scores=True,
|
||||||
|
output_logits=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
**logits_processor_kwargs,
|
||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
# The two outputs must match and their shape must be as expected
|
||||||
|
self._check_similar_generate_outputs(low_output, high_output)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
|||||||
Reference in New Issue
Block a user