Move eos_token_id to stopping criteria (#29459)
* add eos stopping criteria * minor fix * Update tests/generation/test_stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * check eos is not None and fix tests * make style and fixup * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/__init__.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * camel case everywhere * call stopping criteria list for candidate ids * make style and fixup * Empty commit * Empty commit to pass flaky test * set max length in PromptLookupCandidateGenerator * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * lets fix this typo in docs * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update PR * empty commit --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
31c575bcf1
commit
0efcf32351
@@ -1899,14 +1899,12 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
last_assistant_token_is_eos = False
|
||||
max_matches = 5
|
||||
validated_tokens, n_matches = _speculative_sampling(
|
||||
candidate_input_ids,
|
||||
candidate_logits,
|
||||
candidate_length,
|
||||
new_logits,
|
||||
last_assistant_token_is_eos,
|
||||
max_matches,
|
||||
)
|
||||
self.assertTrue(n_matches.item() == 2)
|
||||
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
|
||||
|
||||
Reference in New Issue
Block a user