[Core generation] Adds support for static KV cache (#27931)

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Arthur
2024-02-08 19:50:34 +09:00
committed by GitHub
parent 4b236aed76
commit 115ac94d06
19 changed files with 474 additions and 232 deletions

View File

@@ -362,6 +362,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
pass
@parameterized.expand([("linear",), ("dynamic",)])
@unittest.skip("TODO @gante fix this for Llama")
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
@@ -507,9 +508,19 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
with self.subTest(f"{padding_side}"):
torch.testing.assert_close(
res_eager,
res_sdpa,
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch