Dynamic number of speculative tokens in order to accelerate speculative decoding (#33258)
* optimal Speculation Lookahead based on probability * update peer finished condition * add support to do_sample True * add stopping criteria * gitignore * add print * remove prints * minor * minor * git ignore * adding test to stopping ConfidenceCriteria * doc + format * add doc * Update .gitignore * update docstring and default value of assistant_confidence_threshold * add docstring * Update src/transformers/generation/configuration_utils.py implicit default value (None) Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * style fix --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -26,6 +26,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import (
|
||||
ConfidenceCriteria,
|
||||
EosTokenCriteria,
|
||||
MaxLengthCriteria,
|
||||
MaxTimeCriteria,
|
||||
@@ -100,6 +101,23 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
input_ids[:, -1] = 1
|
||||
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
|
||||
|
||||
def test_confidence_criteria(self):
|
||||
criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)
|
||||
|
||||
vocab_size = 250
|
||||
length = 5
|
||||
|
||||
input_ids = ids_tensor((1, length), vocab_size)
|
||||
scores = (torch.randn((1, vocab_size)),)
|
||||
|
||||
# Simulate high confidence by setting the probability of the last token to be high
|
||||
scores[0][0, input_ids[0, -1]] = 10.0 # Logits before softmax
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
# Simulate low confidence by setting the probability of the last token to be low
|
||||
scores[0][0, input_ids[0, -1]] = -10.0 # Logits before softmax
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
def test_validate_stopping_criteria(self):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user