Fix multiple eos_token_ids in model.generate(...) (#21461)
* add tests with multiple eos_token_ids * make math.prod instead of sum * make fixup * fix long and also use np.prod since math.prod does not exist <python 3.8 * make fixup * add prod util * use prod util instead of np.prod * make fixup * previous .long location * use tensor ops * remove prod * remove prod * update device * make fixup * fix none
This commit is contained in:
@@ -2609,7 +2609,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [873]
|
||||
eos_token_id = [873, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
@@ -2634,7 +2634,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [225]
|
||||
eos_token_id = [225, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
@@ -2660,7 +2660,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [846]
|
||||
eos_token_id = [846, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
@@ -2683,7 +2683,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [873]
|
||||
eos_token_id = [873, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user