[FA2] Fix flash attention 2 fine-tuning with Falcon (#26852)

fix fa2 + dropout issue
This commit is contained in:
Younes Belkada
2023-10-17 15:38:03 +02:00
committed by GitHub
parent 4b423e6074
commit 41c42f85f6
2 changed files with 5 additions and 1 deletions

View File

@@ -2810,6 +2810,10 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test