From 153755ee386ac73e04814a94337abcb1208ff5d1 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:21:54 +0200 Subject: [PATCH] [`FA` / `tests`] Add use_cache tests for FA models (#26415) * add use_cache tests for FA * fixup --- tests/test_modeling_common.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b8d7367dd7..8c2a277b4b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2908,6 +2908,35 @@ class ModelTesterMixin: self.assertTrue(torch.equal(out, out_fa)) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + import torch + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False + ) + global_rng = random.Random()