Fix Gemma2IntegrationTest (#38492)

* fix

* fix

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* update

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-06-02 22:45:09 +02:00
committed by GitHub
parent 1094dd34f7
commit ccc859620a
3 changed files with 61 additions and 18 deletions

View File

@@ -24,6 +24,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
Expectations,
is_flash_attn_2_available,
require_flash_attn,
require_read_token,
require_torch,
@@ -282,6 +283,9 @@ class Cohere2IntegrationTest(unittest.TestCase):
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention function should be coherent and identical.
"""
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
self.skipTest("FlashAttention2 is required for this test.")
if torch_device == "xpu" and attn_implementation == "flash_attention_2":
self.skipTest(reason="Intel XPU doesn't support falsh_attention_2 as of now.")