Llama: make slow tests green 🟢 (#33138)

This commit is contained in:
Joao Gante
2024-08-27 14:44:42 +01:00
committed by GitHub
parent 9956c2bc98
commit c6b23fda65
31 changed files with 39 additions and 180 deletions

View File

@@ -16,6 +16,8 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from .utils.import_utils import is_torchdynamo_compiling
@dataclass @dataclass
class AttentionMaskConverter: class AttentionMaskConverter:
@@ -243,30 +245,33 @@ class AttentionMaskConverter:
is_training: bool = False, is_training: bool = False,
) -> bool: ) -> bool:
""" """
Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
In case no token is masked in the `attention_mask` argument, if `query_length == 1` or In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
`key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
passed).
""" """
_, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
key_value_length = query_length + past_key_values_length key_value_length = query_length + past_key_values_length
is_tracing = ( is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
ignore_causal_mask = False ignore_causal_mask = False
if attention_mask is None: if attention_mask is None:
# TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
# or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). # shape, thus SDPA's `is_causal` argument is rightfully updated
# (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
# `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
# hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
# which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
# Thus, we only set `ignore_causal_mask = True` if the model is set to training. # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
# #
# Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
# ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
if ( if (
(is_training or not is_tracing) (is_training or not is_tracing)
and (query_length == 1 or key_value_length == query_length) and (query_length == 1 or key_value_length == query_length)
@@ -281,8 +286,9 @@ class AttentionMaskConverter:
# For query_length == 1, causal attention and bi-directional attention are the same. # For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True ignore_causal_mask = True
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
# SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108 # Reference: https://github.com/pytorch/pytorch/issues/108108
# TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
@@ -363,11 +369,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = ( is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask=attention_mask, attention_mask=attention_mask,
@@ -439,11 +441,7 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
_, key_value_length = mask.shape _, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length tgt_len = tgt_len if tgt_len is not None else key_value_length
is_tracing = ( is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()
torch.jit.is_tracing()
or isinstance(mask, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
if not is_tracing and torch.all(mask == 1): if not is_tracing and torch.all(mask == 1):

View File

@@ -790,11 +790,6 @@ class BloomModel(BloomPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1433,11 +1433,6 @@ class ChameleonModel(ChameleonPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -632,11 +632,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -912,11 +912,6 @@ class CohereModel(CoherePreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1163,11 +1163,6 @@ class DbrxModel(DbrxPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -931,11 +931,6 @@ class GemmaModel(GemmaPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -846,11 +846,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1017,11 +1017,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -941,11 +941,6 @@ class GPTJModel(GPTJPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1474,11 +1474,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1136,11 +1136,6 @@ class JetMoeModel(JetMoePreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1038,11 +1038,6 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -37,6 +37,7 @@ from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_1 from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging
from ...utils.backbone_utils import load_backbone from ...utils.backbone_utils import load_backbone
from ...utils.import_utils import is_torchdynamo_compiling
from .configuration_mask2former import Mask2FormerConfig from .configuration_mask2former import Mask2FormerConfig
@@ -1999,11 +2000,7 @@ class Mask2FormerMaskPredictor(nn.Module):
def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None): def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
mask_embeddings = self.mask_embedder(outputs.transpose(0, 1)) mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
is_tracing = ( is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
torch.jit.is_tracing()
or isinstance(outputs, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
# Sum up over the channels # Sum up over the channels
if is_tracing and not is_torch_greater_or_equal_than_2_1: if is_tracing and not is_torch_greater_or_equal_than_2_1:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly

View File

@@ -39,6 +39,7 @@ from ...utils import (
requires_backends, requires_backends,
) )
from ...utils.backbone_utils import load_backbone from ...utils.backbone_utils import load_backbone
from ...utils.import_utils import is_torchdynamo_compiling
from ..detr import DetrConfig from ..detr import DetrConfig
from .configuration_maskformer import MaskFormerConfig from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig from .configuration_maskformer_swin import MaskFormerSwinConfig
@@ -1680,11 +1681,7 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the auxiliary predictions (one for each decoder's layer) # get the auxiliary predictions (one for each decoder's layer)
auxiliary_logits: List[str, Tensor] = [] auxiliary_logits: List[str, Tensor] = []
is_tracing = ( is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
torch.jit.is_tracing()
or isinstance(outputs, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
# This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list # This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list
if self.config.use_auxiliary_loss: if self.config.use_auxiliary_loss:
stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states) stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)

View File

@@ -852,11 +852,6 @@ class MistralModel(MistralPreTrainedModel):
use_cache: bool, use_cache: bool,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self._attn_implementation == "flash_attention_2": if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache: if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]

View File

@@ -1121,11 +1121,6 @@ class MixtralModel(MixtralPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -919,11 +919,6 @@ class NemotronModel(NemotronPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -958,11 +958,6 @@ class OlmoModel(OlmoPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -769,11 +769,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1054,11 +1054,6 @@ class PhiModel(PhiPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1094,11 +1094,6 @@ class Phi3Model(Phi3PreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -959,11 +959,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1132,11 +1132,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1171,11 +1171,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -34,6 +34,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.import_utils import is_torchdynamo_compiling
from .configuration_recurrent_gemma import RecurrentGemmaConfig from .configuration_recurrent_gemma import RecurrentGemmaConfig
@@ -329,9 +330,7 @@ class RecurrentGemmaRglru(nn.Module):
# Apply gamma normalization to the input. We need to clip the derivatives of # Apply gamma normalization to the input. We need to clip the derivatives of
# `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying # `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying
multiplier = 1 multiplier = 1
tracing = isinstance(activations, torch.fx.Proxy) or ( tracing = isinstance(activations, torch.fx.Proxy) or is_torchdynamo_compiling()
hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()
)
if not torch.jit.is_tracing() and not tracing: if not torch.jit.is_tracing() and not tracing:
multiplier = SqrtBoundDerivative.apply(1 - a_square) multiplier = SqrtBoundDerivative.apply(1 - a_square)
multiplier = reset + ~reset * multiplier multiplier = reset + ~reset * multiplier
@@ -747,10 +746,6 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
) )
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# Ignore copy # Ignore copy
def _update_causal_mask(self, attention_mask, input_tensor, cache_position): def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
dtype, device = input_tensor.dtype, input_tensor.device dtype, device = input_tensor.dtype, input_tensor.device

View File

@@ -1046,11 +1046,6 @@ class StableLmModel(StableLmPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -933,11 +933,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1428,11 +1428,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask

View File

@@ -1803,6 +1803,8 @@ class GenerationTesterMixin:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported") self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device) model = model_class(config).to(torch_device)
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
input_ids = inputs_dict["input_ids"].to(torch_device) input_ids = inputs_dict["input_ids"].to(torch_device)
# creates two sets of *different* inputs with the same shape # creates two sets of *different* inputs with the same shape
half_batch_size = input_ids.shape[0] // 2 half_batch_size = input_ids.shape[0] // 2
@@ -1815,22 +1817,14 @@ class GenerationTesterMixin:
} }
for model_inputs in input_ids_sets: for model_inputs in input_ids_sets:
# dynamic cache # eager dynamic cache
output_dynamic = model.generate(model_inputs, **generation_kwargs) output_dynamic = model.generate(model_inputs, **generation_kwargs)
# eager static cache # end-to-end compiled dynamic cache
torch.compiler.reset()
model.generation_config.cache_implementation = "static"
output_static = model.generate(model_inputs, **generation_kwargs)
self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
torch.compiler.reset() torch.compiler.reset()
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
generation_config = copy.deepcopy(model.generation_config) generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs) generation_config.update(**generation_kwargs)
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
output_compiled = compiled_generate(model_inputs, generation_config=generation_config) output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())

View File

@@ -726,8 +726,10 @@ class LlamaIntegrationTest(unittest.TestCase):
An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences
from llama 3.1.'s RoPE can be detected from llama 3.1.'s RoPE can be detected
""" """
# diff on `EXPECTED_TEXT`:
# 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results.
EXPECTED_TEXT = ( EXPECTED_TEXT = (
"Tell me about the french revolution. The french revolution was a period of radical social and political " "Tell me about the french revolution. The french revolution was a period of radical political and social "
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked " "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the " "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
"First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative " "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative "
@@ -779,8 +781,8 @@ class LlamaIntegrationTest(unittest.TestCase):
torch.allclose( torch.allclose(
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device), EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
out.logits[0, 0, :15], out.logits[0, 0, :15],
atol=1e-3, atol=1e-2,
rtol=1e-3, rtol=1e-2,
) )
) )
@@ -816,8 +818,8 @@ class LlamaIntegrationTest(unittest.TestCase):
torch.allclose( torch.allclose(
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device), EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
out.logits[0, 0, :15], out.logits[0, 0, :15],
atol=1e-3, atol=1e-2,
rtol=1e-3, rtol=1e-2,
) )
) )
@@ -887,6 +889,7 @@ class LlamaIntegrationTest(unittest.TestCase):
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
# Static Cache + compile # Static Cache + compile
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate( generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"