Universal Assisted Generation: Assisted generation with any assistant model (by Intel Labs) (#33383)
* Update candidate_generator.py * Update utils.py * add lookbehind params to _get_candidate_generator * make fixup * add unit tests * fix failing tests * add docstrings * fix docstrings; remove non-optimized AnyTokenizer * added any tokenizer generation correctness test * make fixup * fix assertion syntax * PR review fixes * address additional PR comments * fix tests * remove stropping criteria arg * make fixup * add AssistantConfig * fix prev_tokens branching * pass tokenizers through `generate()`kwargs * fix lookbehind values; tokenizer params WIP * fixup * AssistantConfig * remove AssistantConfig; apply PR suggestions * restructure tests * fixup * fix assistant_tokenizer arg validation * fixup * fix tests in TestAssistedCandidateGeneratorDifferentTokenizers * fix class docstring * PR suggestions * doc * doc update and improvements to `_validate_assistant()` --------- Co-authored-by: mosheber <moshe.berchansky@intel.com>
This commit is contained in:
@@ -88,6 +88,7 @@ if is_torch_available():
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
|
||||
from transformers.generation.utils import _speculative_sampling
|
||||
|
||||
|
||||
@@ -3510,6 +3511,34 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(test_bos_id == gen_output[0, 0])
|
||||
self.assertTrue(generation_config.bos_token_id is None)
|
||||
|
||||
def test_speculative_decoding_equals_regular_decoding(self):
|
||||
draft_name = "double7/vicuna-68m"
|
||||
target_name = "Qwen/Qwen2-0.5B-Instruct"
|
||||
|
||||
draft_model = AutoModelForCausalLM.from_pretrained(draft_name)
|
||||
target_model = AutoModelForCausalLM.from_pretrained(target_name)
|
||||
|
||||
assistant_tokenizer = AutoTokenizer.from_pretrained(draft_name)
|
||||
target_tokenizer = AutoTokenizer.from_pretrained(target_name)
|
||||
|
||||
prompt_size = torch.randint(low=20, high=100, size=(1,))
|
||||
max_new_tokens = torch.randint(low=10, high=50, size=(1,))
|
||||
input_ids = (torch.rand(1, prompt_size[0]) * 100).to(int) + 50
|
||||
|
||||
max_new_tokens_item = max_new_tokens[0].item()
|
||||
expected_out = target_model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens_item)
|
||||
predicted_out = target_model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_new_tokens_item,
|
||||
assistant_model=draft_model,
|
||||
target_tokenizer=target_tokenizer,
|
||||
assistant_tokenizer=assistant_tokenizer,
|
||||
)
|
||||
|
||||
self.assertEqual(expected_out.shape, predicted_out.shape)
|
||||
self.assertTrue((expected_out == predicted_out).all().item())
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch_multi_gpu
|
||||
def test_generate_with_static_cache_multi_gpu(self):
|
||||
@@ -3884,3 +3913,41 @@ class TokenHealingTestCase(unittest.TestCase):
|
||||
# bos_token_id is required when no input ids nor inputs_embeds is passed
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(max_length=20, bos_token_id=None)
|
||||
|
||||
|
||||
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
|
||||
def test_no_intersection(self):
|
||||
prompt = np.array([[1, 2, 3]])
|
||||
prompt_plus_new_tokens = np.array([[4, 5, 6]])
|
||||
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
|
||||
self.assertEqual(result, (None, None, None))
|
||||
|
||||
def test_complete_overlap(self):
|
||||
prompt = np.array([[1, 2, 3]])
|
||||
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
|
||||
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
||||
prompt, prompt_plus_new_tokens
|
||||
)
|
||||
self.assertEqual(discrep_length, 0)
|
||||
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
|
||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||
|
||||
def test_partial_overlap(self):
|
||||
prompt = np.array([[1, 2, 3]])
|
||||
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
|
||||
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
||||
prompt, prompt_plus_new_tokens
|
||||
)
|
||||
self.assertEqual(discrep_length, 0)
|
||||
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
|
||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||
|
||||
def test_no_new_tokens(self):
|
||||
prompt = np.array([[1, 2, 3]])
|
||||
prompt_plus_new_tokens = np.array([[1, 2, 3]])
|
||||
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
|
||||
prompt, prompt_plus_new_tokens
|
||||
)
|
||||
self.assertEqual(discrep_length, 0)
|
||||
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
|
||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||
|
||||
Reference in New Issue
Block a user