Fix flash attention bugs with Mistral and Falcon (#27625)
* fix various bugs with flash attention * bump * fix test * fix mistral * use skiptest instead of return that may be misleading * fix on review
This commit is contained in:
@@ -2835,7 +2835,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
@@ -2860,7 +2860,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -2957,7 +2957,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -3050,7 +3050,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -3093,7 +3093,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -3109,7 +3109,7 @@ class ModelTesterMixin:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||
# make sure we do left padding
|
||||
# make sure we do right padding
|
||||
dummy_attention_mask[:, :-1] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
|
||||
@@ -3138,7 +3138,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -3179,7 +3179,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -3279,7 +3279,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# TODO: to change it in the future with other relevant auto classes
|
||||
|
||||
Reference in New Issue
Block a user