Self-speculation (Layer-Skip Llama) (#34240)

* 😅

* early exit (#34244)

* mvp

* docs and tests

* a few fixes

* no shared cache

* Apply suggestions from code review

Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>

* docs

* make fix-copies

* cohere fix

* [test all]

* [test all] consistent model code copies

* [test all] make fix-copies :D

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>

* Update src/transformers/generation/candidate_generator.py

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* [test all] don't use a stand-alone attribute; fix test

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Arthur
2024-11-19 13:20:07 +01:00
committed by GitHub
parent 5de58d5955
commit 54739a320e
15 changed files with 185 additions and 58 deletions

View File

@@ -4108,6 +4108,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
def test_assisted_generation_early_exit(self):
"""
Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache
manipulation, which will cause the test to fail if something goes wrong there.
"""
expected_output = "Alice and Bob are playing a game of poker. Alice has a pair of 8s and Bob has a pair"
prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
original_decoded = tokenizer.batch_decode(original_outputs, skip_special_tokens=True)
self.assertEqual(original_decoded, [expected_output])
outputs_assisted = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
decoded_assisted = tokenizer.batch_decode(outputs_assisted, skip_special_tokens=True)
self.assertEqual(decoded_assisted, [expected_output])
@require_torch
class TokenHealingTestCase(unittest.TestCase):