Move eos_token_id to stopping criteria (#29459)
* add eos stopping criteria * minor fix * Update tests/generation/test_stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * check eos is not None and fix tests * make style and fixup * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/__init__.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * camel case everywhere * call stopping criteria list for candidate ids * make style and fixup * Empty commit * Empty commit to pass flaky test * set max length in PromptLookupCandidateGenerator * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * lets fix this typo in docs * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update PR * empty commit --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
31c575bcf1
commit
0efcf32351
@@ -26,6 +26,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import (
|
||||
EosTokenCriteria,
|
||||
MaxLengthCriteria,
|
||||
MaxNewTokensCriteria,
|
||||
MaxTimeCriteria,
|
||||
@@ -98,6 +99,22 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_eos_token_criteria(self):
|
||||
criteria = EosTokenCriteria(eos_token_id=0)
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:, -1] = 0
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:2, -1] = 0
|
||||
input_ids[2, -1] = 1
|
||||
self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:, -1] = 1
|
||||
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
|
||||
|
||||
def test_validate_stopping_criteria(self):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
||||
|
||||
|
||||
@@ -1899,14 +1899,12 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
last_assistant_token_is_eos = False
|
||||
max_matches = 5
|
||||
validated_tokens, n_matches = _speculative_sampling(
|
||||
candidate_input_ids,
|
||||
candidate_logits,
|
||||
candidate_length,
|
||||
new_logits,
|
||||
last_assistant_token_is_eos,
|
||||
max_matches,
|
||||
)
|
||||
self.assertTrue(n_matches.item() == 2)
|
||||
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
|
||||
|
||||
Reference in New Issue
Block a user