Avoid flaky generation sampling tests (#21445)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -780,7 +780,7 @@ class GenerationTesterMixin:
|
|||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
||||||
|
|
||||||
# check `generate()` and `sample()` are equal
|
# check `generate()` and `sample()` are equal
|
||||||
output_sample, output_generate = self._sample_generate(
|
output_sample, output_generate = self._sample_generate(
|
||||||
|
|||||||
@@ -621,7 +621,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt
|
|||||||
config.forced_eos_token_id = None
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
||||||
|
|
||||||
num_return_sequences = 2
|
num_return_sequences = 2
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -670,7 +670,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt
|
|||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
config.forced_eos_token_id = None
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
||||||
|
|
||||||
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user