Fix flash attention bugs with Mistral and Falcon (#27625)

* fix various bugs with flash attention

* bump

* fix test

* fix mistral

* use skiptest instead of return that may be misleading

* fix on review
This commit is contained in:
fxmarty
2023-11-21 15:20:44 +01:00
committed by GitHub
parent f93c1e9ece
commit 82cc0a79ac
5 changed files with 50 additions and 32 deletions

View File

@@ -2835,7 +2835,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
model = model_class(config)
@@ -2860,7 +2860,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -2957,7 +2957,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3050,7 +3050,7 @@ class ModelTesterMixin:
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3093,7 +3093,7 @@ class ModelTesterMixin:
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3109,7 +3109,7 @@ class ModelTesterMixin:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
@@ -3138,7 +3138,7 @@ class ModelTesterMixin:
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -3179,7 +3179,7 @@ class ModelTesterMixin:
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3279,7 +3279,7 @@ class ModelTesterMixin:
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes