From 59d5edef34ae0fa56065a2e863736d4f133c558b Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:01:25 +0100 Subject: [PATCH] Avoid flaky generation sampling tests (#21445) * fix * fix --------- Co-authored-by: ydshieh --- tests/generation/test_utils.py | 2 +- .../switch_transformers/test_modeling_switch_transformers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index a97f036e12..cb1c2460db 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -780,7 +780,7 @@ class GenerationTesterMixin: forced_eos_token_id=model.config.forced_eos_token_id, max_length=max_length, ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) # check `generate()` and `sample()` are equal output_sample, output_generate = self._sample_generate( diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 1afeb2e484..20e089a313 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -621,7 +621,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt config.forced_eos_token_id = None model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval() - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) num_return_sequences = 2 if model.config.is_encoder_decoder: @@ -670,7 +670,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt config.eos_token_id = None config.forced_eos_token_id = None - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()