Fix Gemma2IntegrationTest (#38492)
* fix
* fix
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* update
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
is_flash_attn_2_available,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
@@ -282,6 +283,9 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
|
||||
Outputs for every attention functions should be coherent and identical.
|
||||
"""
|
||||
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
|
||||
self.skipTest("FlashAttention2 is required for this test.")
|
||||
|
||||
if torch_device == "xpu" and attn_implementation == "flash_attention_2":
|
||||
self.skipTest(reason="Intel XPU doesn't support falsh_attention_2 as of now.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user