Re-enable SDPA's FA2 path (#30070)

* tentatively re-enable FA2 + SDPA

* better comment

* _ignore_causal_mask_sdpa as staticmethod

* type hints

* use past_seen_tokens instead

* enable copied from for sdpa

* ruff

* llama simplifications on review

* remove unnecessary self.is_causal check

* fix copies

* cleaning

* precise message

* better doc

* add test

* simplify

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
fxmarty
2024-04-17 22:21:00 +02:00
committed by GitHub
parent e4ea19b958
commit 05bdef16b6
5 changed files with 176 additions and 60 deletions

View File

@@ -3772,6 +3772,42 @@ class ModelTesterMixin:
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_can_dispatch_on_flash(self):
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if not torch.version.cuda or major < 8:
self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0")
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.model_type in ["llava", "llava_next", "vipllava"]:
self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input")
if config.model_type in ["idefics"]:
self.skipTest("Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
model.to(torch_device)
inputs_dict.pop("attention_mask", None)
inputs_dict.pop("decoder_attention_mask", None)
for name, inp in inputs_dict.items():
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
inputs_dict[name] = inp.to(torch.float16)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):