Fix test_eos_token_id_int_and_list_top_k_top_sampling (#22826)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -2515,12 +2515,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
tokens = tokenizer(text, return_tensors="pt").to(torch_device)
|
tokens = tokenizer(text, return_tensors="pt").to(torch_device)
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
# Only some seeds will work both on CPU/GPU for a fixed `expectation` value.
|
||||||
|
# The selected seed is not guaranteed to work on all torch versions.
|
||||||
|
torch.manual_seed(1)
|
||||||
eos_token_id = 846
|
eos_token_id = 846
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(1)
|
||||||
eos_token_id = [846, 198]
|
eos_token_id = [846, 198]
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|||||||
Reference in New Issue
Block a user