Fix multiple eos_token_ids in model.generate(...) (#21461)

* add tests with multiple eos_token_ids

* make math.prod instead of sum

* make fixup

* fix long and also use np.prod since math.prod does not exist <python 3.8

* make fixup

* add prod util

* use prod util instead of np.prod

* make fixup

* previous .long location

* use tensor ops

* remove prod

* remove prod

* update device

* make fixup

* fix none
This commit is contained in:
Motoki Wu
2023-02-08 10:48:46 -08:00
committed by GitHub
parent 06d940efc3
commit 9960506cbe
2 changed files with 19 additions and 10 deletions

View File

@@ -2609,7 +2609,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873]
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
@@ -2634,7 +2634,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [225]
eos_token_id = [225, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
@@ -2660,7 +2660,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [846]
eos_token_id = [846, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
@@ -2683,7 +2683,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873]
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))