Support for Flash Attention 3 (#38972)
* Support `flash_attn_3`. Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper - Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (dropout will likely be supported soon, so this will need to be updated, along with `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch). An example Llama implementation is included in `modeling_llama.py`, but other models would still need to be updated. Based on https://github.com/huggingface/transformers/pull/36190, which has model implementations and examples which could be merged. * Add tests for Flash Attention 2 and 3 parity * ci fix * FA2 compatibility - `_prepare_flash_attention_from_position_ids` -> `prepare_fa2_from_position_ids` - Remove bettertransformer check in Flash Attention 3 - Merge tests - Add licensing * ci fix * Test naming consistency * ci fix * Deprecation warning for `prepare_fa2_from_position_ids` * ci fix
This commit is contained in:
144
tests/generation/test_flash_attention_parity.py
Normal file
144
tests/generation/test_flash_attention_parity.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright 2025 Eduard Durech and SGLang team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Usage:
|
||||
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow
|
||||
|
||||
|
||||
class FlashAttentionParityTest(unittest.TestCase):
    """Compares greedy generation of the same checkpoint under Flash Attention 2 and Flash Attention 3.

    The single test loads the model once per attention implementation, checks that the per-step scores
    and the decoded text agree, and prints a small latency report.
    """

    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
    def _lcs(self, X, Y):
        """Return the length of the longest common subsequence of `X` and `Y`.

        `X` and `Y` are compared element by element; the test passes strings, so the comparison
        there is character-level. Either sequence may be empty, in which case the result is 0.
        """
        # Rolling-row dynamic programming: `prev[j]` is the LCS length of the already-consumed
        # prefix of X against Y[:j]. Only the previous row is ever read, so one row is kept.
        prev = [0] * (len(Y) + 1)
        for x in X:
            curr = [0]
            for j, y in enumerate(Y, start=1):
                if x == y:
                    curr.append(prev[j - 1] + 1)
                else:
                    curr.append(max(prev[j], curr[j - 1]))
            prev = curr

        return prev[-1]

    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
    def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
        """Return the ROUGE-L F-measure for each aligned pair of the two lists.

        Pairs are formed with `zip`, so surplus elements of the longer list are ignored.
        A pair scores 0.0 when the two items share no common subsequence (including when
        either item is empty).
        """
        rouge_l_scores = []

        for s1, s2 in zip(output_strs_list1, output_strs_list2):
            lcs_len = self._lcs(s1, s2)
            precision = lcs_len / len(s1) if len(s1) > 0 else 0
            recall = lcs_len / len(s2) if len(s2) > 0 else 0
            if precision + recall > 0:
                fmeasure = (2 * precision * recall) / (precision + recall)
            else:
                fmeasure = 0.0
            rouge_l_scores.append(fmeasure)

        return rouge_l_scores

    def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5, max_new_tokens=20):
        """Return the mean wall time of one greedy `model.generate` call, measured with CUDA events.

        Args:
            model: object exposing `generate(**inputs, max_new_tokens=..., do_sample=...)`.
            inputs: mapping unpacked into `generate` as keyword arguments.
            n_warmup: number of untimed calls made before measuring.
            n_runs: number of timed calls; must be at least 1.
            max_new_tokens: generation length used for both warm-up and timed calls.

        Returns:
            `start.elapsed_time(end) / n_runs`, which the caller reports in milliseconds.

        Raises:
            ValueError: if `n_runs` is smaller than 1 (the mean would be undefined).
        """
        if n_runs < 1:
            raise ValueError(f"n_runs must be at least 1, got {n_runs}")

        for _ in range(n_warmup):
            model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Make sure warm-up work has finished before the start event is recorded.
        torch.cuda.synchronize()

        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)

        start_time.record()
        for _ in range(n_runs):
            model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        end_time.record()
        # The end event must have completed before `elapsed_time` can be queried.
        torch.cuda.synchronize()

        return start_time.elapsed_time(end_time) / n_runs

    @pytest.mark.flash_attn_3_test
    @require_torch_gpu
    @require_flash_attn
    @require_flash_attn_3
    @slow
    def test_flash_attention_2_3_parity(self):
        """Checks that FA2 and FA3 give matching scores and text for a short greedy generation."""
        model_id = "meta-llama/Llama-3.2-1B-Instruct"
        prompt = "The ETH AI Center is"
        max_new_tokens = 20
        # Shared by both models so the two generations are configured identically.
        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "output_scores": True,
            "return_dict_in_generate": True,
        }

        # 1. Load FA2 model and tokenizer
        model_2 = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # 2. Load FA3 model
        try:
            model_3 = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_3",
            ).to("cuda")
        except (ValueError, ImportError) as e:
            pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")

        # 3. Generate with both models
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

        with torch.no_grad():
            output_2 = model_2.generate(**inputs, **generate_kwargs)
            output_3 = model_3.generate(**inputs, **generate_kwargs)

        # 4. Correctness check
        # 4a. Logits
        logits_2 = torch.stack(output_2.scores)
        logits_3 = torch.stack(output_3.scores)
        torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
        logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
        logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
        # Only reported below; the hard check on the scores is the `assert_close` above.
        max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()

        # 4b. Generated text
        text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
        text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
        rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
        self.assertGreater(rouge_score, 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})")

        # 5. Performance check
        with torch.no_grad():
            time_2 = self._benchmark_generation(model_2, inputs, max_new_tokens=max_new_tokens)
            time_3 = self._benchmark_generation(model_3, inputs, max_new_tokens=max_new_tokens)

        print(f"\n--- Flash Attention 2 vs 3 Parity Test on {model_id} ---")
        print(f"Prompt: '{prompt}'")
        print(f"Generated text with Flash Attention 2: {text_2}")
        print(f"Generated text with Flash Attention 3: {text_3}")
        print(f"ROUGE-L: {rouge_score}")
        print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
        print(f"Flash Attention 2 latency: {time_2:.2f} ms")
        print(f"Flash Attention 3 latency: {time_3:.2f} ms")
        print(f"Speed-up: {time_2 / time_3:.2f}x")
        print("---")
|
||||
@@ -34,6 +34,7 @@ from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_flash_attn,
|
||||
require_flash_attn_3,
|
||||
require_optimum_quanto,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
@@ -2292,6 +2293,7 @@ class GenerationTesterMixin:
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn_2",
|
||||
"flash_attention_3": "_supports_flash_attn_3",
|
||||
}
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -2369,6 +2371,14 @@ class GenerationTesterMixin:
|
||||
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
|
||||
self._test_attention_implementation("flash_attention_2")
|
||||
|
||||
@pytest.mark.flash_attn_3_test
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_eager_matches_fa3_generate(self):
|
||||
"""Tests that generate has equivalent outputs with FA3 and eager attention implementations."""
|
||||
self._test_attention_implementation("flash_attention_3")
|
||||
|
||||
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
||||
internal_batch_size = (
|
||||
|
||||
@@ -84,6 +84,7 @@ from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_flash_attn,
|
||||
require_flash_attn_3,
|
||||
require_non_hpu,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
@@ -3129,18 +3130,19 @@ class ModelTesterMixin:
|
||||
f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
|
||||
r"""
|
||||
Tests the equivalence between the eager and flash attention implementations.
|
||||
This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
|
||||
"""
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
|
||||
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
|
||||
):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -3148,7 +3150,7 @@ class ModelTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
@@ -3163,9 +3165,12 @@ class ModelTesterMixin:
|
||||
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask = dummy_attention_mask[:1]
|
||||
dummy_attention_mask[:, 1:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
|
||||
if padding_side == "left":
|
||||
dummy_attention_mask[:, 1:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
else:
|
||||
dummy_attention_mask[:, :-1] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
if model.config.is_encoder_decoder:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
||||
|
||||
@@ -3220,11 +3225,22 @@ class ModelTesterMixin:
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
|
||||
if padding_side == "left":
|
||||
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(dummy_input, **other_inputs)
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(dummy_input, **other_inputs)
|
||||
else:
|
||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="left")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@@ -3232,92 +3248,23 @@ class ModelTesterMixin:
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="right")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_3_inference_equivalence(self):
|
||||
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="left")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask = dummy_attention_mask[:1]
|
||||
dummy_attention_mask[:, :-1] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
||||
|
||||
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
else:
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
other_inputs = {
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": dummy_attention_mask,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
else:
|
||||
other_inputs = {
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_3_inference_equivalence_right_padding(self):
|
||||
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="right")
|
||||
|
||||
def test_attn_implementation_composite_models(self):
|
||||
"""
|
||||
@@ -3959,24 +3906,21 @@ class ModelTesterMixin:
|
||||
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
def test_flash_attn_2_can_dispatch_composite_models(self):
|
||||
def flash_attn_can_dispatch_composite_models(self, attn_implementation: str):
|
||||
"""
|
||||
Tests if composite models can dispatch on FA2 if the sub-models support FA2.
|
||||
Tests if composite models can dispatch on flash attention if the sub-models support it.
|
||||
The tests is needed as we handle differently composite models and we cannot check them
|
||||
with above tests. If any of the sub-models does not support FA2, we'll raise an error when dispatching
|
||||
with above tests. If any of the sub-models does not support flash attention, we'll raise an error when dispatching
|
||||
that particular sub-model. Otherwise we dispatch safely in all sub-models, where "sub-models" are specific
|
||||
backbone models (LM/vision/audio/etc)
|
||||
"""
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
if not is_torch_fp16_available_on_device(torch_device):
|
||||
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||
if not is_torch_bf16_available_on_device(torch_device):
|
||||
self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)")
|
||||
|
||||
torch_dtype = torch.float16
|
||||
torch_dtype = torch.bfloat16
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -3987,44 +3931,64 @@ class ModelTesterMixin:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
|
||||
sub_models_supporting_fa2 = [
|
||||
module._supports_flash_attn_2
|
||||
sub_models_supporting_fa = [
|
||||
(
|
||||
module._supports_flash_attn_3
|
||||
if attn_implementation == "flash_attention_3"
|
||||
else module._supports_flash_attn_2
|
||||
)
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_fa2_all_modules = (
|
||||
all(sub_models_supporting_fa2)
|
||||
if len(sub_models_supporting_fa2) > 0
|
||||
else model._supports_flash_attn_2
|
||||
supports_fa_all_modules = (
|
||||
all(sub_models_supporting_fa)
|
||||
if len(sub_models_supporting_fa) > 0
|
||||
else (
|
||||
model._supports_flash_attn_3
|
||||
if attn_implementation == "flash_attention_3"
|
||||
else model._supports_flash_attn_2
|
||||
)
|
||||
)
|
||||
if not supports_fa2_all_modules:
|
||||
if not supports_fa_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="flash_attention_2",
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
else:
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch_dtype, attn_implementation=attn_implementation
|
||||
)
|
||||
for key in model_fa2.config:
|
||||
if isinstance(getattr(model_fa2.config, key), PretrainedConfig):
|
||||
sub_config = getattr(model_fa2.config, key)
|
||||
self.assertTrue(sub_config._attn_implementation == "flash_attention_2")
|
||||
for key in model_fa.config:
|
||||
if isinstance(getattr(model_fa.config, key), PretrainedConfig):
|
||||
sub_config = getattr(model_fa.config, key)
|
||||
self.assertTrue(sub_config._attn_implementation == attn_implementation)
|
||||
|
||||
has_fa2 = False
|
||||
for name, submodule in model_fa2.named_modules():
|
||||
has_fa = False
|
||||
for name, submodule in model_fa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if (
|
||||
"Attention" in class_name
|
||||
and getattr(submodule, "config", None)
|
||||
and submodule.config._attn_implementation == "flash_attention_2"
|
||||
and submodule.config._attn_implementation == attn_implementation
|
||||
):
|
||||
has_fa2 = True
|
||||
has_fa = True
|
||||
break
|
||||
if not has_fa2:
|
||||
raise ValueError("The FA2 model should have FA2 layers")
|
||||
if not has_fa:
|
||||
raise ValueError(f"The {attn_implementation} model should have {attn_implementation} layers")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
def test_flash_attn_2_can_dispatch_composite_models(self):
|
||||
self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_2")
|
||||
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_3_test
|
||||
def test_flash_attn_3_can_dispatch_composite_models(self):
|
||||
self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_3")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@@ -4121,27 +4085,29 @@ class ModelTesterMixin:
|
||||
|
||||
assert not loss.isnan().any()
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
if not (
|
||||
model_class._supports_flash_attn_2
|
||||
if attn_implementation == "flash_attention_2"
|
||||
else model_class._supports_flash_attn_3
|
||||
):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
@@ -4151,7 +4117,7 @@ class ModelTesterMixin:
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
self.skipTest("Model does not support position_ids")
|
||||
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
if (not fa_kwargs) and "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
continue # this model doesn't accept position ids as input
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -4166,26 +4132,40 @@ class ModelTesterMixin:
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
k: v[dummy_attention_mask.bool()].unsqueeze(0)
|
||||
for k, v in inputs_dict.items()
|
||||
if not k == "attention_mask"
|
||||
}
|
||||
# add position_ids
|
||||
padfree_inputs_dict["position_ids"] = (
|
||||
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
|
||||
.long()
|
||||
.unsqueeze(0)
|
||||
.to(torch_device)
|
||||
)
|
||||
if fa_kwargs:
|
||||
# flatten
|
||||
features = [
|
||||
{"input_ids": i[a.bool()].tolist()}
|
||||
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||
]
|
||||
|
||||
# add position_ids + fa_kwargs
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
|
||||
batch = data_collator(features)
|
||||
padfree_inputs_dict = {
|
||||
k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
|
||||
}
|
||||
else:
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
k: v[dummy_attention_mask.bool()].unsqueeze(0)
|
||||
for k, v in inputs_dict.items()
|
||||
if not k == "attention_mask"
|
||||
}
|
||||
# add position_ids
|
||||
padfree_inputs_dict["position_ids"] = (
|
||||
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
|
||||
.long()
|
||||
.unsqueeze(0)
|
||||
.to(torch_device)
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict)
|
||||
res_padfree = model(**padfree_inputs_dict)
|
||||
@@ -4195,119 +4175,96 @@ class ModelTesterMixin:
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.float16).eps
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
attn_implementation="flash_attention_2", fa_kwargs=True
|
||||
)
|
||||
|
||||
max_new_tokens = 30
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
self.skipTest("Model does not support position_ids")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# ensure left padding, to adapt for some models
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
features = [
|
||||
{"input_ids": i[a.bool()].tolist()}
|
||||
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||
]
|
||||
|
||||
# add position_ids + fa_kwargs
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
|
||||
batch = data_collator(features)
|
||||
batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}
|
||||
|
||||
res_padded = model(**inputs_dict)
|
||||
res_padfree = model(**batch_accelerator)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.float16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@require_flash_attn
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
def test_flash_attn_2_from_config(self):
|
||||
def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
|
||||
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
attn_implementation="flash_attention_3", fa_kwargs=True
|
||||
)
|
||||
|
||||
def flash_attn_from_config(self, attn_implementation: str):
|
||||
r"""
|
||||
Tests if the model can be loaded with `attn_implementation` from the config and if the
|
||||
weights are not randomly initialized.
|
||||
"""
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
|
||||
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
|
||||
):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# TODO: to change it in the future with other relevant auto classes
|
||||
fa2_model = model_class._from_config(
|
||||
config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
|
||||
fa_model = model_class._from_config(
|
||||
config, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[fa2_model.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
dummy_input = inputs_dict[fa_model.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||
|
||||
if fa2_model.config.is_encoder_decoder:
|
||||
if fa_model.config.is_encoder_decoder:
|
||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
_ = fa2_model(
|
||||
_ = fa_model(
|
||||
dummy_input,
|
||||
attention_mask=dummy_attention_mask,
|
||||
decoder_input_ids=dummy_decoder_input_ids,
|
||||
decoder_attention_mask=dummy_decoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
_ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
|
||||
_ = fa_model(dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fa2_model.save_pretrained(tmpdirname)
|
||||
fa_model.save_pretrained(tmpdirname)
|
||||
model_from_pretrained = model_class.from_pretrained(tmpdirname)
|
||||
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
|
||||
self.assertTrue(model_from_pretrained.config._attn_implementation != attn_implementation)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_from_config(self):
|
||||
self.flash_attn_from_config(attn_implementation="flash_attention_2")
|
||||
|
||||
# Gate on Flash Attention 3 itself, matching the other FA3 tests in this file.
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
def test_flash_attn_3_from_config(self):
    """Runs the shared `flash_attn_from_config` check with `attn_implementation="flash_attention_3"`."""
    self.flash_attn_from_config(attn_implementation="flash_attention_3")
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
# Sequence in which all but the last token is the same
|
||||
|
||||
@@ -77,6 +77,7 @@ from transformers.utils import (
|
||||
)
|
||||
from transformers.utils.import_utils import (
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_3_available,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_npu_available,
|
||||
@@ -676,6 +677,9 @@ class ModelUtilsTest(TestCasePlus):
|
||||
if is_flash_attn_available():
|
||||
attn_implementation_available.append("flash_attention_2")
|
||||
|
||||
if is_flash_attn_3_available():
|
||||
attn_implementation_available.append("flash_attention_3")
|
||||
|
||||
for requested_attn_implementation in attn_implementation_available:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
TINY_MISTRAL, attn_implementation=requested_attn_implementation
|
||||
@@ -700,6 +704,9 @@ class ModelUtilsTest(TestCasePlus):
|
||||
if is_flash_attn_available():
|
||||
attn_implementation_available.append("flash_attention_2")
|
||||
|
||||
if is_flash_attn_3_available():
|
||||
attn_implementation_available.append("flash_attention_3")
|
||||
|
||||
for requested_attn_implementation in attn_implementation_available:
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
|
||||
# Ensure the config was set correctly
|
||||
|
||||
Reference in New Issue
Block a user