[FA2] Fix flash attention 2 fine-tuning with Falcon (#26852)
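Fix the Flash Attention 2 (FA2) + dropout issue: read the dropout rate from `self.config.attention_dropout` instead of `self.attention_dropout`.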
This commit is contained in:
@@ -606,7 +606,7 @@ class FalconFlashAttention2(FalconAttention):
|
|||||||
if alibi is not None:
|
if alibi is not None:
|
||||||
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
|
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
|
||||||
|
|
||||||
attn_dropout = self.attention_dropout if self.training else 0.0
|
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
|||||||
@@ -2810,6 +2810,10 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
|
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
|
||||||
|
|
||||||
|
# check with inference + dropout
|
||||||
|
model.train()
|
||||||
|
_ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
|
|||||||
Reference in New Issue
Block a user