Dynamic number of speculative tokens in order to accelerate speculative decoding (#33258)

* optimal Speculation Lookahead based on probability

* update peer finished condition

* add support for do_sample=True

* add stopping criteria

* gitignore

* add print

* remove prints

* minor

* minor

* git ignore

* add test for the ConfidenceCriteria stopping criterion

* doc + format

* add doc

* Update .gitignore

* update docstring and default value of assistant_confidence_threshold

* add docstring

* Update src/transformers/generation/configuration_utils.py

implicit default value (None)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* style fix

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Jonathan Mamou
2024-09-11 15:22:28 +03:00
committed by GitHub
parent 42babe8548
commit 7a51cbc65f
6 changed files with 57 additions and 0 deletions

View File

@@ -26,6 +26,7 @@ if is_torch_available():
import torch
from transformers.generation import (
ConfidenceCriteria,
EosTokenCriteria,
MaxLengthCriteria,
MaxTimeCriteria,
@@ -100,6 +101,23 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids[:, -1] = 1
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
def test_confidence_criteria(self):
    """Check that `ConfidenceCriteria` reacts to the score of the last token.

    With a threshold of 0.5, the criteria is expected to evaluate falsy when
    the logit of the last input token is large, and truthy when that logit
    is very small.
    """
    vocab_size, length = 250, 5
    criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)
    input_ids = ids_tensor((1, length), vocab_size)
    logits = torch.randn((1, vocab_size))
    scores = (logits,)
    last_token = input_ids[0, -1]

    # High confidence: raise the pre-softmax logit of the last token.
    logits[0, last_token] = 10.0
    self.assertFalse(criteria(input_ids, scores))

    # Low confidence: drop the same pre-softmax logit far below the others.
    logits[0, last_token] = -10.0
    self.assertTrue(criteria(input_ids, scores))
def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)