Adaptive dynamic number of speculative tokens (#34156)

* initial commit

* update strategy

* add tradeoff FPR TPR with cost

* all probs

* fix

* fix

* fix style

* Update src/transformers/generation/configuration_utils.py

shorter docstring

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* import guard

* fix style

* add is_sklearn_available condition

* vectorizing to flatten the for-loop

* fix style

* disable adaptation for UAG

* update doc

* add TestAssistedCandidateGeneratorUpdateStrategy

* fix style

* protect import

* fix style

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Jonathan Mamou
2024-12-05 18:07:33 +02:00
committed by GitHub
parent b0a51e5cff
commit e27465c801
4 changed files with 177 additions and 2 deletions

View File

@@ -92,9 +92,16 @@ if is_torch_available():
WatermarkDetector,
WatermarkingConfig,
)
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
from transformers.generation.candidate_generator import (
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
)
from transformers.generation.utils import _speculative_sampling
from unittest.mock import patch
from transformers.utils import is_sklearn_available
class GenerationTesterMixin:
input_name = "input_ids"
@@ -4312,3 +4319,110 @@ class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
def setUp(self):
checkpoint = "EleutherAI/pythia-160m-deduped"
self.assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint)
self.assistant_model.generation_config.assistant_confidence_threshold = 0.4
self.model_kwargs = {}
self.input_ids = torch.randint(1, 10, (1, 9))
self.candidate_generator = AssistedCandidateGenerator(
input_ids=self.input_ids,
assistant_model=self.assistant_model,
generation_config=self.assistant_model.generation_config,
model_kwargs=self.model_kwargs,
)
self.candidate_generator.probs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
self.original_probs = self.candidate_generator.probs
self.original_threshold = self.assistant_model.generation_config.assistant_confidence_threshold
def assert_no_sklearn(self):
with patch("transformers.utils.import_utils._sklearn_available", False):
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, self.original_matches)
self.assertEqual(self.candidate_generator.probs, self.original_probs)
self.assertEqual(
self.assistant_model.generation_config.assistant_confidence_threshold, self.original_threshold
)
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_no_matches_short(self, sklearn_available):
print("test_update_candidate_strategy_no_matches_short")
self.original_matches = []
self.candidate_generator.matches = self.original_matches
self.num_matches = 0
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [0])
self.assertEqual(self.candidate_generator.probs, [0.9])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_mix_matches_3(self, sklearn_available):
self.original_matches = [1, 0, 1, 0, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 3
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 0, 1, 0, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_4(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 4
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 1])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_3(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 3
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_2(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 2
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.3)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_1(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 1
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
else:
self.assert_no_sklearn()