Avoid flaky generation sampling tests (#21445)

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-02-03 22:01:25 +01:00
committed by GitHub
parent 31c351c4d3
commit 59d5edef34
2 changed files with 3 additions and 3 deletions

View File

@@ -780,7 +780,7 @@ class GenerationTesterMixin:
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(