[tests] Fix: static cache implementation is not compatible with attn_implementation == "flash_attention_2" (#32039)
* add flash attention check
* fix
* fix
This commit is contained in:
@@ -290,7 +290,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(decoded[0].endswith(last_output))
|
self.assertTrue(decoded[0].endswith(last_output))
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
@parameterized.expand(["eager", "sdpa"])
|
||||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
||||||
EXPECTED_GENERATION = [
|
EXPECTED_GENERATION = [
|
||||||
"The best color is the one that complements the skin tone of the",
|
"The best color is the one that complements the skin tone of the",
|
||||||
@@ -330,7 +330,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
@parameterized.expand(["eager", "sdpa"])
|
||||||
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
|
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
|
||||||
EXPECTED_GENERATION = [
|
EXPECTED_GENERATION = [
|
||||||
"The best color isЋ the one that complements the skin tone of",
|
"The best color isЋ the one that complements the skin tone of",
|
||||||
|
|||||||
Reference in New Issue
Block a user