Self-speculation (Layer-Skip Llama) (#34240)
* 😅 * early exit (#34244) * mvp * docs and tests * a few fixes * no shared cache * Apply suggestions from code review Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org> * docs * make fix-copies * cohere fix * [test all] * [test all] consistent model code copies * [test all] make fix-copies :D * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org> * Update src/transformers/generation/candidate_generator.py * Update src/transformers/generation/configuration_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * [test all] don't use a stand-alone attribute; fix test --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Joao Gante <joao@huggingface.co> Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -4108,6 +4108,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
|
||||
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
|
||||
|
||||
def test_assisted_generation_early_exit(self):
|
||||
"""
|
||||
Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache
|
||||
manipulation, which will cause the test to fail if something goes wrong there.
|
||||
"""
|
||||
expected_output = "Alice and Bob are playing a game of poker. Alice has a pair of 8s and Bob has a pair"
|
||||
|
||||
prompt = "Alice and Bob"
|
||||
checkpoint = "facebook/layerskip-llama3.2-1B"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
|
||||
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
|
||||
original_decoded = tokenizer.batch_decode(original_outputs, skip_special_tokens=True)
|
||||
self.assertEqual(original_decoded, [expected_output])
|
||||
|
||||
outputs_assisted = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
|
||||
decoded_assisted = tokenizer.batch_decode(outputs_assisted, skip_special_tokens=True)
|
||||
self.assertEqual(decoded_assisted, [expected_output])
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user