[FA2] Fix flash attention 2 fine-tuning with Falcon (#26852)
Fix the FA2 + dropout issue.
This commit is contained in:
@@ -606,7 +606,7 @@ class FalconFlashAttention2(FalconAttention):
|
||||
if alibi is not None:
|
||||
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
|
||||
|
||||
attn_dropout = self.attention_dropout if self.training else 0.0
|
||||
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states get silently cast to float32. Hence, we need
|
||||
|
||||
@@ -2810,6 +2810,10 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user