From 9470c00042d0ff37e52a6e442970547b42d29b6c Mon Sep 17 00:00:00 2001
From: Guang Yang <42389959+guangy10@users.noreply.github.com>
Date: Thu, 17 Oct 2024 08:33:19 -0700
Subject: [PATCH] Llama3 and Llama2 are ExecuTorch compatible (#34101)

Llama3_1b and Llama2_7b are ExecuTorch compatible

Co-authored-by: Guang Yang
---
 tests/models/llama/test_modeling_llama.py | 69 +++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index d43a0fb13f..fe521ea410 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -23,6 +23,7 @@ from packaging import version
 from parameterized import parameterized
 
 from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
     backend_empty_cache,
     require_bitsandbytes,
@@ -916,6 +917,74 @@ class LlamaIntegrationTest(unittest.TestCase):
         static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
 
+    @slow
+    @require_read_token
+    def test_export_static_cache(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        from transformers.integrations.executorch import (
+            TorchExportableModuleWithStaticCache,
+            convert_and_export_with_cache,
+        )
+
+        llama_models = {
+            "meta-llama/Llama-3.2-1B": [
+                "Simply put, the theory of relativity states that 1) the speed of light is the same for all "
+                "observers, regardless of their location, and 2) the laws of physics are the same for all observers"
+            ],
+            "meta-llama/Llama-3.2-3B": [
+                "Simply put, the theory of relativity states that 1. the speed of light is constant, and 2. "
+                "the speed of light is the fastest speed possible"
+            ],
+            "meta-llama/Llama-2-7b-hf": [
+                "Simply put, the theory of relativity states that 1) the speed of light is a constant, and 2) "
+                "the laws of physics are the same for all",
+            ],
+        }
+
+        for llama_model_ckp, EXPECTED_TEXT_COMPLETION in llama_models.items():
+            # Load tokenizer
+            tokenizer = AutoTokenizer.from_pretrained(llama_model_ckp, pad_token="</s>", padding_side="right")
+            max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
+                "input_ids"
+            ].shape[-1]
+
+            # Load model
+            device = "cpu"
+            dtype = torch.bfloat16
+            cache_implementation = "static"
+            attn_implementation = "sdpa"
+            batch_size = 1
+            model = LlamaForCausalLM.from_pretrained(
+                llama_model_ckp,
+                device_map=device,
+                torch_dtype=dtype,
+                attn_implementation=attn_implementation,
+                generation_config=GenerationConfig(
+                    use_cache=True,
+                    cache_implementation=cache_implementation,
+                    max_length=max_generation_length,
+                    cache_config={
+                        "batch_size": batch_size,
+                        "max_cache_len": max_generation_length,
+                    },
+                ),
+            )
+
+            prompts = ["Simply put, the theory of relativity states that "]
+            prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+            prompt_token_ids = prompt_tokens["input_ids"]
+            max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
+
+            # Static Cache + export
+            exported_program = convert_and_export_with_cache(model)
+            ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
+                exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
+            )
+            ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
+            self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
+
 
 @slow
 @require_torch_accelerator