From 41c42f85f61b0c333e940c6d424fdfb81e180a7b Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Tue, 17 Oct 2023 15:38:03 +0200
Subject: [PATCH] [`FA2`] Fix flash attention 2 fine-tuning with Falcon (#26852)

fix fa2 + dropout issue
---
 src/transformers/models/falcon/modeling_falcon.py | 2 +-
 tests/test_modeling_common.py                     | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 35313e8d9e..5fb155775a 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -606,7 +606,7 @@ class FalconFlashAttention2(FalconAttention):
         if alibi is not None:
             raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
 
-        attn_dropout = self.attention_dropout if self.training else 0.0
+        attn_dropout = self.config.attention_dropout if self.training else 0.0
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 5a239cf0fb..019650a98e 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2810,6 +2810,10 @@ class ModelTesterMixin:
 
                 self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
 
+                # check with inference + dropout
+                model.train()
+                _ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test