Fix SDPA sliding window compatibility (#30127)
* fix sdpa + sliding window * give credit Co-authored-by: ehuaa <ehuamail@163.com> * remove unnecessary warning * fix typog * add test --------- Co-authored-by: ehuaa <ehuamail@163.com>
This commit is contained in:
@@ -305,7 +305,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
|
|||||||
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
||||||
|
|
||||||
key_value_length = input_shape[-1] + past_key_values_length
|
key_value_length = input_shape[-1] + past_key_values_length
|
||||||
batch_size, query_length = input_shape
|
_, query_length = input_shape
|
||||||
|
|
||||||
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
||||||
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
||||||
@@ -316,7 +316,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
|
|||||||
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
ignore_causal_mask = False
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
if sliding_window is None or key_value_length < sliding_window:
|
||||||
|
ignore_causal_mask = not is_tracing
|
||||||
|
elif sliding_window is None or key_value_length < sliding_window:
|
||||||
# 4d mask is passed through
|
# 4d mask is passed through
|
||||||
if len(attention_mask.shape) == 4:
|
if len(attention_mask.shape) == 4:
|
||||||
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
||||||
@@ -335,26 +340,17 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
|
|||||||
elif not is_tracing and torch.all(attention_mask == 1):
|
elif not is_tracing and torch.all(attention_mask == 1):
|
||||||
if query_length == 1:
|
if query_length == 1:
|
||||||
# For query_length == 1, causal attention and bi-directional attention are the same.
|
# For query_length == 1, causal attention and bi-directional attention are the same.
|
||||||
attention_mask = None
|
ignore_causal_mask = True
|
||||||
elif key_value_length == query_length:
|
elif key_value_length == query_length:
|
||||||
attention_mask = None
|
ignore_causal_mask = True
|
||||||
else:
|
|
||||||
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
|
|
||||||
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
|
||||||
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
|
||||||
pass
|
|
||||||
elif query_length > 1 and key_value_length != query_length:
|
|
||||||
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
|
|
||||||
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
|
|
||||||
attention_mask = True
|
|
||||||
elif is_tracing:
|
|
||||||
raise ValueError(
|
|
||||||
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is None:
|
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
|
||||||
|
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
||||||
|
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
||||||
|
|
||||||
|
if ignore_causal_mask:
|
||||||
expanded_4d_mask = None
|
expanded_4d_mask = None
|
||||||
elif attention_mask is True:
|
elif attention_mask is None:
|
||||||
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
||||||
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1006,6 +1006,7 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
(batch_size, seq_length),
|
(batch_size, seq_length),
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
|
|||||||
@@ -1191,6 +1191,7 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
(batch_size, seq_length),
|
(batch_size, seq_length),
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
|
|||||||
@@ -1017,6 +1017,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
(batch_size, seq_length),
|
(batch_size, seq_length),
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
|
|||||||
@@ -1183,6 +1183,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
(batch_size, seq_length),
|
(batch_size, seq_length),
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
|
|||||||
@@ -995,6 +995,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
(batch_size, seq_length),
|
(batch_size, seq_length),
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
|
|||||||
@@ -3841,6 +3841,57 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||||
|
|
||||||
|
@require_torch_sdpa
|
||||||
|
def test_sdpa_matches_eager_sliding_window(self):
|
||||||
|
WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"]
|
||||||
|
|
||||||
|
if len(self.all_generative_model_classes) == 0:
|
||||||
|
self.skipTest(f"No generative model classes for {self.__class__.__name__}")
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
if config.model_type not in WINDOW_ATTENTION_MODELS:
|
||||||
|
self.skipTest(f"{config.model_type} does not use window attention")
|
||||||
|
|
||||||
|
config.sliding_window = 2
|
||||||
|
|
||||||
|
dummy_input = inputs_dict[model_class.main_input_name]
|
||||||
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
|
||||||
|
self.assertTrue(dummy_input.ndim == 2)
|
||||||
|
self.assertTrue(dummy_input.shape[1] > 6)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
with torch.device(torch_device):
|
||||||
|
model_eager = AutoModelForCausalLM.from_config(
|
||||||
|
config, attn_implementation="eager", torch_dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
model_eager.save_pretrained(tmpdir)
|
||||||
|
|
||||||
|
with torch.device(torch_device):
|
||||||
|
model_sdpa = AutoModelForCausalLM.from_pretrained(
|
||||||
|
tmpdir, attn_implementation="sdpa", torch_dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
model_eager = model_eager.eval()
|
||||||
|
model_sdpa = model_sdpa.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=False,
|
||||||
|
enable_math=True,
|
||||||
|
enable_mem_efficient=False,
|
||||||
|
):
|
||||||
|
res_eager = model_eager(**inputs_dict, return_dict=False)[0]
|
||||||
|
res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
# Only non-padding tokens are expected to match.
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-3)
|
||||||
|
)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
|
|||||||
Reference in New Issue
Block a user