Fix multiple eos_token_ids in model.generate(...) (#21461)
* add tests with multiple eos_token_ids * make math.prod instead of sum * make fixup * fix long and also use np.prod since math.prod does not exist <python 3.8 * make fixup * add prod util * use prod util instead of np.prod * make fixup * previous .long location * use tensor ops * remove prod * remove prod * update device * make fixup * fix none
This commit is contained in:
@@ -1757,6 +1757,7 @@ class GenerationMixin:
|
|||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -1980,8 +1981,10 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id is not None:
|
if eos_token_id_tensor is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
unfinished_sequences = (
|
||||||
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
# stop when each sentence is finished, or if we exceed the maximum length
|
# stop when each sentence is finished, or if we exceed the maximum length
|
||||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||||
@@ -2129,6 +2132,7 @@ class GenerationMixin:
|
|||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -2223,8 +2227,10 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id is not None:
|
if eos_token_id_tensor is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
unfinished_sequences = (
|
||||||
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
# stop when each sentence is finished, or if we exceed the maximum length
|
# stop when each sentence is finished, or if we exceed the maximum length
|
||||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||||
@@ -2392,6 +2398,7 @@ class GenerationMixin:
|
|||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -2489,8 +2496,10 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id is not None:
|
if eos_token_id_tensor is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
unfinished_sequences = (
|
||||||
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
# stop when each sentence is finished, or if we exceed the maximum length
|
# stop when each sentence is finished, or if we exceed the maximum length
|
||||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||||
|
|||||||
@@ -2609,7 +2609,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
eos_token_id = [873]
|
eos_token_id = [873, 198]
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
@@ -2634,7 +2634,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
eos_token_id = [225]
|
eos_token_id = [225, 198]
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
@@ -2660,7 +2660,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
eos_token_id = [846]
|
eos_token_id = [846, 198]
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
@@ -2683,7 +2683,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
eos_token_id = [873]
|
eos_token_id = [873, 198]
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user