From 90247d3e011fef58e2b4e95f70ae2c1aacaf61db Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 18 Apr 2023 16:04:51 +0200 Subject: [PATCH] Fix `test_eos_token_id_int_and_list_top_k_top_sampling` (#22826) * fix --------- Co-authored-by: ydshieh --- tests/generation/test_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index dffaba4fb6..86963a1269 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2515,12 +2515,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi tokens = tokenizer(text, return_tensors="pt").to(torch_device) model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - torch.manual_seed(0) + # Only some seeds will work both on CPU/GPU for a fixed `expectation` value. + # The selected seed is not guaranteed to work on all torch versions. + torch.manual_seed(1) eos_token_id = 846 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) - torch.manual_seed(0) + torch.manual_seed(1) eos_token_id = [846, 198] generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0]))