Fix flash attention bugs with Mistral and Falcon (#27625)

* fix various bugs with flash attention

* bump

* fix test

* fix mistral

* use skipTest instead of an early return, which may be misleading (the test would be reported as passed)

* address review feedback
This commit is contained in:
fxmarty
2023-11-21 15:20:44 +01:00
committed by GitHub
parent f93c1e9ece
commit 82cc0a79ac
5 changed files with 50 additions and 32 deletions

View File

@@ -22,6 +22,7 @@ from parameterized import parameterized
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_accelerator,
@@ -385,6 +386,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):