Adaptive dynamic number of speculative tokens (#34156)
* initial commit
* update strategy
* add tradeoff FPR TPR with cost
* all probs
* fix
* fix
* fix style
* Update src/transformers/generation/configuration_utils.py
  shorter docstring
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* import guard
* fix style
* add is_sklearn_available condition
* vectorizing to flatten the for-loop
* fix style
* disable adaptation for UAG
* update doc
* add TestAssistedCandidateGeneratorUpdateStrategy
* fix style
* protect import
* fix style

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -92,9 +92,16 @@ if is_torch_available():
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
|
||||
from transformers.generation.candidate_generator import (
|
||||
AssistedCandidateGenerator,
|
||||
AssistedCandidateGeneratorDifferentTokenizers,
|
||||
)
|
||||
from transformers.generation.utils import _speculative_sampling
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.utils import is_sklearn_available
|
||||
|
||||
|
||||
class GenerationTesterMixin:
|
||||
input_name = "input_ids"
|
||||
@@ -4312,3 +4319,110 @@ class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
|
||||
self.assertEqual(discrep_length, 0)
|
||||
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
|
||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||
|
||||
|
||||
class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
    """Tests for `AssistedCandidateGenerator.update_candidate_strategy`.

    Every test seeds the generator's `matches` and `probs` history, calls `update_candidate_strategy` with a given
    number of matches and checks the resulting history and `assistant_confidence_threshold`. Each case is
    parameterized twice: once with the real result of `is_sklearn_available()` and once with `False`, in which case
    the update is run with scikit-learn patched as unavailable and must leave all three values untouched.
    """

    def setUp(self):
        checkpoint = "EleutherAI/pythia-160m-deduped"
        self.assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint)
        self.assistant_model.generation_config.assistant_confidence_threshold = 0.4
        self.model_kwargs = {}
        self.input_ids = torch.randint(1, 10, (1, 9))
        self.candidate_generator = AssistedCandidateGenerator(
            input_ids=self.input_ids,
            assistant_model=self.assistant_model,
            generation_config=self.assistant_model.generation_config,
            model_kwargs=self.model_kwargs,
        )
        self.candidate_generator.probs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
        # Keep an independent snapshot rather than a second reference to the same list: if the generator
        # mutated `probs` in place, an aliased "original" would change with it and the comparison in
        # `assert_no_sklearn` could never fail.
        self.original_probs = list(self.candidate_generator.probs)
        self.original_threshold = self.assistant_model.generation_config.assistant_confidence_threshold

    def assert_no_sklearn(self):
        """Run the update with scikit-learn reported as unavailable and check that nothing changed.

        Reads `self.num_matches` and `self.original_matches`, which the calling test must have set.
        """
        with patch("transformers.utils.import_utils._sklearn_available", False):
            self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
            self.assertEqual(self.candidate_generator.matches, self.original_matches)
            self.assertEqual(self.candidate_generator.probs, self.original_probs)
            self.assertEqual(
                self.assistant_model.generation_config.assistant_confidence_threshold, self.original_threshold
            )

    def _check_update(
        self, matches, num_matches, sklearn_available, expected_matches, expected_probs, expected_threshold
    ):
        """Seed the match history, run one strategy update and verify the outcome.

        Args:
            matches: initial content of the generator's `matches` history.
            num_matches: number of matches passed to `update_candidate_strategy`.
            sklearn_available: when falsy, only the "nothing changes" path (`assert_no_sklearn`) is exercised.
            expected_matches: expected `matches` after the update (scikit-learn path only).
            expected_probs: expected `probs` after the update (scikit-learn path only).
            expected_threshold: expected `assistant_confidence_threshold` after the update (scikit-learn path only).
        """
        # Two separate copies: the generator gets its own list and the reference stays pristine, so an
        # in-place mutation by the generator is detectable in `assert_no_sklearn`.
        self.original_matches = list(matches)
        self.candidate_generator.matches = list(matches)
        self.num_matches = num_matches

        if not sklearn_available:
            self.assert_no_sklearn()
            return

        self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
        self.assertEqual(self.candidate_generator.matches, expected_matches)
        self.assertEqual(self.candidate_generator.probs, expected_probs)
        self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, expected_threshold)

    @parameterized.expand([(is_sklearn_available(),), (False,)])
    def test_update_candidate_strategy_no_matches_short(self, sklearn_available):
        self._check_update(
            matches=[],
            num_matches=0,
            sklearn_available=sklearn_available,
            expected_matches=[0],
            expected_probs=[0.9],
            expected_threshold=0.4,
        )

    @parameterized.expand([(is_sklearn_available(),), (False,)])
    def test_update_candidate_strategy_with_mix_matches_3(self, sklearn_available):
        self._check_update(
            matches=[1, 0, 1, 0, 1],
            num_matches=3,
            sklearn_available=sklearn_available,
            expected_matches=[1, 0, 1, 0, 1, 1, 1, 1, 0],
            expected_probs=[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
            expected_threshold=0.2,
        )

    @parameterized.expand([(is_sklearn_available(),), (False,)])
    def test_update_candidate_strategy_with_matches_4(self, sklearn_available):
        self._check_update(
            matches=[1, 1, 1, 1, 1],
            num_matches=4,
            sklearn_available=sklearn_available,
            expected_matches=[1, 1, 1, 1, 1, 1, 1, 1, 1],
            expected_probs=[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
            expected_threshold=0.4,
        )

    @parameterized.expand([(is_sklearn_available(),), (False,)])
    def test_update_candidate_strategy_with_matches_3(self, sklearn_available):
        self._check_update(
            matches=[1, 1, 1, 1, 1],
            num_matches=3,
            sklearn_available=sklearn_available,
            expected_matches=[1, 1, 1, 1, 1, 1, 1, 1, 0],
            expected_probs=[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
            expected_threshold=0.2,
        )

    @parameterized.expand([(is_sklearn_available(),), (False,)])
    def test_update_candidate_strategy_with_matches_2(self, sklearn_available):
        self._check_update(
            matches=[1, 1, 1, 1, 1],
            num_matches=2,
            sklearn_available=sklearn_available,
            expected_matches=[1, 1, 1, 1, 1, 1, 1, 0],
            expected_probs=[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2],
            expected_threshold=0.3,
        )

    @parameterized.expand([(is_sklearn_available(),), (False,)])
    def test_update_candidate_strategy_with_matches_1(self, sklearn_available):
        self._check_update(
            matches=[1, 1, 1, 1, 1],
            num_matches=1,
            sklearn_available=sklearn_available,
            expected_matches=[1, 1, 1, 1, 1, 1, 0],
            expected_probs=[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3],
            expected_threshold=0.4,
        )
||||
|
||||
Reference in New Issue
Block a user