From acab997befc352229fcd99a0eb89884632d7412b Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:09:52 +0200 Subject: [PATCH] Revert "Re-enable SDPA's FA2 path (#30070)" (#30314) * Revert "Re-enable SDPA's FA2 path (#30070)" This reverts commit 05bdef16b611df0946a6a602503f1ace604b6c80. * Revert "Fix quality Olmo + SDPA (#30302)" This reverts commit ec92f983af5295fc92414a37b988d8384785988a. --- src/transformers/modeling_attn_mask_utils.py | 91 +++++++------------ .../models/cohere/modeling_cohere.py | 36 ++------ .../models/gemma/modeling_gemma.py | 36 ++------ .../models/llama/modeling_llama.py | 37 ++------ src/transformers/models/olmo/modeling_olmo.py | 34 ++----- tests/test_modeling_common.py | 36 -------- 6 files changed, 69 insertions(+), 201 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 8ae9b57b6c..43da8917b2 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -234,59 +234,6 @@ class AttentionMaskConverter: return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) - @staticmethod - def _ignore_causal_mask_sdpa( - attention_mask: Optional[torch.Tensor], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, - ) -> bool: - """ - Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. - - In case no token is masked in the `attention_mask` argument, if `query_length == 1` or - `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, - allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). - """ - - batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] - key_value_length = query_length + past_key_values_length - - is_tracing = ( - torch.jit.is_tracing() - or isinstance(inputs_embeds, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - - ignore_causal_mask = False - - if attention_mask is None: - # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or - # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). - # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag. - # - # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`). - if sliding_window is None or key_value_length < sliding_window: - ignore_causal_mask = not is_tracing - elif sliding_window is None or key_value_length < sliding_window: - if len(attention_mask.shape) == 4: - expected_shape = (batch_size, 1, query_length, key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) - elif not is_tracing and torch.all(attention_mask == 1): - if query_length == 1 or key_value_length == query_length: - # For query_length == 1, causal attention and bi-directional attention are the same. - ignore_causal_mask = True - - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 - # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. - - return ignore_causal_mask - def _prepare_4d_causal_attention_mask( attention_mask: Optional[torch.Tensor], @@ -358,6 +305,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa( attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length + _, query_length = input_shape # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. @@ -368,12 +316,37 @@ def _prepare_4d_causal_attention_mask_for_sdpa( or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) ) - ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) + ignore_causal_mask = False + + if attention_mask is None: + if sliding_window is None or key_value_length < sliding_window: + ignore_causal_mask = not is_tracing + elif sliding_window is None or key_value_length < sliding_window: + # 4d mask is passed through + if len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + return attention_mask + + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + elif key_value_length == query_length: + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 if ignore_causal_mask: expanded_4d_mask = None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 950d45ea86..95a7d76827 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -590,15 +590,12 @@ class CohereSdpaAttention(CohereAttention): key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -911,7 +908,9 @@ class CohereModel(CoherePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] + ) # embed positions hidden_states = inputs_embeds @@ -975,31 +974,16 @@ class CohereModel(CoherePreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_seen_tokens: int, - ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens - ): - return None - dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -1007,9 +991,7 @@ class CohereModel(CoherePreTrainedModel): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 6077259d0b..c8b9b11c55 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -570,15 +570,12 @@ class GemmaSdpaAttention(GemmaAttention): key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -891,7 +888,9 @@ class GemmaModel(GemmaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] + ) # embed positions hidden_states = inputs_embeds @@ -961,31 +960,16 @@ class GemmaModel(GemmaPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_seen_tokens: int, - ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens - ): - return None - dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -993,9 +977,7 @@ class GemmaModel(GemmaPreTrainedModel): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2b8e8f6d09..e1afb61be0 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -656,6 +656,7 @@ class LlamaSdpaAttention(LlamaAttention): value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] @@ -666,15 +667,12 @@ class LlamaSdpaAttention(LlamaAttention): key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -989,7 +987,9 @@ class LlamaModel(LlamaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] + ) # embed positions hidden_states = inputs_embeds @@ -1053,31 +1053,16 @@ class LlamaModel(LlamaPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_seen_tokens: int, - ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens - ): - return None - dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -1085,9 +1070,7 @@ class LlamaModel(LlamaPreTrainedModel): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 83637536a1..b8fb01d7b2 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -653,7 +653,6 @@ class OlmoSdpaAttention(OlmoAttention): value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -971,7 +970,9 @@ class OlmoModel(OlmoPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] + ) # embed positions hidden_states = inputs_embeds @@ -1035,32 +1036,17 @@ class OlmoModel(OlmoPreTrainedModel): attentions=all_self_attns, ) + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_seen_tokens: int, - ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens - ): - return None - dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -1068,9 +1054,7 @@ class OlmoModel(OlmoPreTrainedModel): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 71cb28d754..a3cbcc0818 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3772,42 +3772,6 @@ class ModelTesterMixin: self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) - @require_torch_sdpa - @require_torch_gpu - @slow - def test_sdpa_can_dispatch_on_flash(self): - compute_capability = torch.cuda.get_device_capability() - major, _ = compute_capability - - if not torch.version.cuda or major < 8: - self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0") - - for model_class in self.all_model_classes: - if not model_class._supports_sdpa: - self.skipTest(f"{model_class.__name__} does not support SDPA") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if config.model_type in ["llava", "llava_next", "vipllava"]: - self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input") - if config.model_type in ["idefics"]: - self.skipTest("Idefics currently (transformers==4.39.1) requires an image_attention_mask input") - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa") - model.to(torch_device) - - inputs_dict.pop("attention_mask", None) - inputs_dict.pop("decoder_attention_mask", None) - - for name, inp in inputs_dict.items(): - if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]: - inputs_dict[name] = inp.to(torch.float16) - - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): - _ = model(**inputs_dict) - @require_torch_sdpa @slow def test_eager_matches_sdpa_generate(self):