Rename _supports_flash_attn_2 in examples and tests (#39471)
* delete `_supports_flash_attn_2` from examples and tests * simplify docs
This commit is contained in:
committed by
GitHub
parent
3a152e3a5c
commit
8c102e2eb1
@@ -3471,9 +3471,7 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
|
||||
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
|
||||
):
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -3969,22 +3967,12 @@ class ModelTesterMixin:
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
|
||||
sub_models_supporting_fa = [
|
||||
(
|
||||
module._supports_flash_attn_3
|
||||
if attn_implementation == "flash_attention_3"
|
||||
else module._supports_flash_attn_2
|
||||
)
|
||||
module._supports_flash_attn
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_fa_all_modules = (
|
||||
all(sub_models_supporting_fa)
|
||||
if len(sub_models_supporting_fa) > 0
|
||||
else (
|
||||
model._supports_flash_attn_3
|
||||
if attn_implementation == "flash_attention_3"
|
||||
else model._supports_flash_attn_2
|
||||
)
|
||||
all(sub_models_supporting_fa) if len(sub_models_supporting_fa) > 0 else model._supports_flash_attn
|
||||
)
|
||||
if not supports_fa_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -4037,7 +4025,7 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -4104,9 +4092,8 @@ class ModelTesterMixin:
|
||||
torch_dtype = torch.float16
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
cls = self._torch_compile_train_cls
|
||||
model = cls(config).to(device=torch_device, dtype=torch_dtype)
|
||||
cls = self._torch_compile_train_cls # e.g. LlamaFroCausalLM
|
||||
model = cls(config, attn_implementation="flash_attention_2").to(device=torch_device, dtype=torch_dtype)
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||
@@ -4268,9 +4255,7 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
|
||||
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
|
||||
):
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user