Universal Assisted Generation: Assisted generation with any assistant model (by Intel Labs) (#33383)

* Update candidate_generator.py

* Update utils.py

* add lookbehind params to _get_candidate_generator

* make fixup

* add unit tests

* fix failing tests

* add docstrings

* fix docstrings; remove non-optimized AnyTokenizer

* added any tokenizer generation correctness test

* make fixup

* fix assertion syntax

* PR review fixes

* address additional PR comments

* fix tests

* remove stropping criteria arg

* make fixup

* add AssistantConfig

* fix prev_tokens branching

* pass tokenizers through `generate()`kwargs

* fix lookbehind values; tokenizer params WIP

* fixup

* AssistantConfig

* remove AssistantConfig; apply PR suggestions

* restructure tests

* fixup

* fix assistant_tokenizer arg validation

* fixup

* fix tests in TestAssistedCandidateGeneratorDifferentTokenizers

* fix class docstring

* PR suggestions

* doc

* doc update and improvements to `_validate_assistant()`

---------

Co-authored-by: mosheber <moshe.berchansky@intel.com>
This commit is contained in:
Daniel Korat
2024-10-10 15:41:53 +03:00
committed by GitHub
parent dda3f91d06
commit fb0c6b521d
4 changed files with 440 additions and 10 deletions

View File

@@ -88,6 +88,7 @@ if is_torch_available():
WatermarkDetector,
WatermarkingConfig,
)
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
from transformers.generation.utils import _speculative_sampling
@@ -3510,6 +3511,34 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)
def test_speculative_decoding_equals_regular_decoding(self):
draft_name = "double7/vicuna-68m"
target_name = "Qwen/Qwen2-0.5B-Instruct"
draft_model = AutoModelForCausalLM.from_pretrained(draft_name)
target_model = AutoModelForCausalLM.from_pretrained(target_name)
assistant_tokenizer = AutoTokenizer.from_pretrained(draft_name)
target_tokenizer = AutoTokenizer.from_pretrained(target_name)
prompt_size = torch.randint(low=20, high=100, size=(1,))
max_new_tokens = torch.randint(low=10, high=50, size=(1,))
input_ids = (torch.rand(1, prompt_size[0]) * 100).to(int) + 50
max_new_tokens_item = max_new_tokens[0].item()
expected_out = target_model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens_item)
predicted_out = target_model.generate(
input_ids,
do_sample=False,
max_new_tokens=max_new_tokens_item,
assistant_model=draft_model,
target_tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
)
self.assertEqual(expected_out.shape, predicted_out.shape)
self.assertTrue((expected_out == predicted_out).all().item())
@pytest.mark.generate
@require_torch_multi_gpu
def test_generate_with_static_cache_multi_gpu(self):
@@ -3884,3 +3913,41 @@ class TokenHealingTestCase(unittest.TestCase):
# bos_token_id is required when no input ids nor inputs_embeds is passed
with self.assertRaises(ValueError):
model.generate(max_length=20, bos_token_id=None)
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[4, 5, 6]])
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
self.assertEqual(result, (None, None, None))
def test_complete_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_partial_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_no_new_tokens(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))