[FA2] Fix flash attention 2 fine-tuning with Falcon (#26852)
fix fa2 + dropout issue
This commit is contained in:
@@ -2810,6 +2810,10 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user