From 82cc0a79ac796184806a137a7000ad1b2036fe5b Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:20:44 +0100 Subject: [PATCH] Fix flash attention bugs with Mistral and Falcon (#27625) * fix various bugs with flash attention * bump * fix test * fix mistral * use skiptest instead of return that may be misleading * fix on review --- .../models/falcon/modeling_falcon.py | 6 +++ .../models/mistral/modeling_mistral.py | 2 +- tests/models/llama/test_modeling_llama.py | 2 + tests/models/mistral/test_modeling_mistral.py | 54 +++++++++++-------- tests/test_modeling_common.py | 18 +++---- 5 files changed, 50 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d4c647c846..e7538eb40b 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -564,6 +564,12 @@ class FalconFlashAttention2(FalconAttention): past_key_value = (key_layer, value_layer) if use_cache else None + # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view operations. 
+ query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + if alibi is not None: raise ValueError("`alibi` is not supported when `use_flash_attn` is True") diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 3aefb03d8c..e56ebc0310 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -838,7 +838,7 @@ class MistralModel(MistralPreTrainedModel): attention_mask is not None and hasattr(self.config, "_flash_attn_2_enabled") and self.config._flash_attn_2_enabled - and past_key_values is not None + and use_cache ): is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 21fb4f44d2..55b36c7102 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -22,6 +22,7 @@ from parameterized import parameterized from transformers import LlamaConfig, is_torch_available, set_seed from transformers.testing_utils import ( + require_bitsandbytes, require_flash_attn, require_torch, require_torch_accelerator, @@ -385,6 +386,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_flash_attn @require_torch_gpu + @require_bitsandbytes @pytest.mark.flash_attn_test @slow def test_flash_attn_2_generate_padding_right(self): diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index b30e70ba71..31426435d0 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -375,9 +375,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi import torch for model_class in self.all_generative_model_classes: - if not 
model_class._supports_flash_attn_2: - return - config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -405,36 +402,49 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_generate_use_cache(self): import torch - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - return + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True - ) - model_fa.to(torch_device) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Mistral apparently does not support right padding + use_cache with FA2. 
+ dummy_attention_mask[:, -1] = 1 model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + tmpdirname, + torch_dtype=torch.float16, + use_flash_attention_2=True, + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False ) - model.to(torch_device) - dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device) - - _ = model(dummy_input, output_hidden_states=True).hidden_states[-1] - with self.assertRaises(ValueError): - _ = model_fa( - dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True - ).hidden_states[-1] + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + self.skipTest("Mistral flash attention does not support right padding") @require_torch diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 49d64dc207..9d9e96db43 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2835,7 +2835,7 @@ class ModelTesterMixin: for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") model = model_class(config) @@ -2860,7 +2860,7 @@ class ModelTesterMixin: for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -2957,7 +2957,7 @@ class ModelTesterMixin: for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} 
does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3050,7 +3050,7 @@ class ModelTesterMixin: for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3093,7 +3093,7 @@ class ModelTesterMixin: for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3109,7 +3109,7 @@ class ModelTesterMixin: dummy_input = dummy_input.to(torch.float16) dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do left padding + # make sure we do right padding dummy_attention_mask[:, :-1] = 1 dummy_attention_mask[:, -1:] = 0 @@ -3138,7 +3138,7 @@ class ModelTesterMixin: for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -3179,7 +3179,7 @@ class ModelTesterMixin: for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3279,7 +3279,7 @@ class ModelTesterMixin: for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - return + self.skipTest(f"{model_class.__name__} does not support Flash Attention 
2") config, _ = self.model_tester.prepare_config_and_inputs_for_common() # TODO: to change it in the future with other relevant auto classes