From a6adc05e6b10479474ecdb1dbd3f9c6925d5e332 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 17 Jan 2024 11:50:53 +0100 Subject: [PATCH] symbolic_trace: add past_key_values, llama, sdpa support (#28447) * torch.fx: add pkv, llama, sdpa support * Update src/transformers/models/opt/modeling_opt.py * remove spaces * trigger ci * use explicit variable names --- src/transformers/modeling_attn_mask_utils.py | 17 ++++++----- src/transformers/utils/fx.py | 29 ++++++++++++++++++ tests/models/llama/test_modeling_llama.py | 1 + tests/test_modeling_common.py | 32 +++++++++++++++++--- 4 files changed, 68 insertions(+), 11 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index f0964f9402..67555239c7 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -132,6 +132,7 @@ class AttentionMaskConverter: expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( attention_mask_2d.device ) + if causal_4d_mask is not None: expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) @@ -346,10 +347,10 @@ def _prepare_4d_causal_attention_mask_for_sdpa( key_value_length = input_shape[-1] + past_key_values_length batch_size, query_length = input_shape - # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. # TODO: Fix this as well when using torchdynamo with fullgraph=True. - is_tracing = torch.jit.is_tracing() + is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) if attention_mask is not None: # 4d mask is passed through @@ -367,10 +368,8 @@ def _prepare_4d_causal_attention_mask_for_sdpa( ) return attention_mask - elif torch.all(attention_mask == 1): - if is_tracing: - pass - elif query_length == 1: + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1: # For query_length == 1, causal attention and bi-directional attention are the same. attention_mask = None elif key_value_length == query_length: @@ -405,7 +404,11 @@ def _prepare_4d_causal_attention_mask_for_sdpa( # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if query_length > 1: + # + # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent + # controlflow that can not be captured properly. + # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case. + if query_length > 1 and not is_tracing: expanded_4d_mask = AttentionMaskConverter._unmask_unattended( expanded_4d_mask, attention_mask, unmasked_value=0.0 ) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 2b1d95b651..36feadff3c 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -131,6 +131,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "gptj", "hubert", "layoutlm", + "llama", "lxmert", "m2m_100", "marian", @@ -156,6 +157,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ # "xlnet", ] +_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"] + _REGULAR_SUPPORTED_MODELS = [] for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS: if isinstance(item, dict): @@ -514,6 +517,14 @@ def torch_nn_functional_one_hot(tensor, num_classes=-1): return torch.empty(shape, device="meta") +def torch_nn_functional_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None +): + target_length = query.shape[-2] + head_dim = value.shape[-1] + return torch.empty((*query.shape[:-2], target_length, head_dim), device="meta") + + def torch_nn_mseloss(self, input, target): if self.reduction == "none": shape = target.shape @@ -597,6 +608,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { torch.Tensor.unsqueeze: torch_tensor_unsqueeze, torch.unique_consecutive: torch_unique_consecutive, torch.nn.functional.one_hot: torch_nn_functional_one_hot, + torch.nn.functional.scaled_dot_product_attention: torch_nn_functional_scaled_dot_product_attention, torch.nn.MSELoss: torch_nn_mseloss, torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss, torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss, @@ -868,6 +880,23 @@ class HFTracer(Tracer): inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device) elif "mask" in input_name or "ids" in input_name: inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) + elif "past_key_values" in input_name: + if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE: + raise NotImplementedError( + f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added." + ) + num_heads = model.config.num_attention_heads + head_dim = model.config.hidden_size // model.config.num_attention_heads + + cache_shape = (shape[0], num_heads, 0, head_dim) + pkv = tuple( + ( + torch.rand(cache_shape, dtype=torch.float, device=device), + torch.rand(cache_shape, dtype=torch.float, device=device), + ) + for i in range(model.config.num_hidden_layers) + ) + inputs_dict[input_name] = pkv else: shape_with_hidden_size = shape + [model.config.hidden_size] inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 427f94f873..2a52590369 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -292,6 +292,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) test_headmasking = False test_pruning = False + fx_compatible = True def setUp(self): self.model_tester = LlamaModelTester(self) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 15c610563d..69cf04d37a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -118,7 +118,7 @@ if is_flax_available(): ) if is_torch_fx_available(): - from transformers.utils.fx import symbolic_trace + from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace def _config_zero_init(config): @@ -1004,7 +1004,9 @@ class ModelTesterMixin: def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): if not is_torch_fx_available() or not self.fx_compatible: - return + self.skipTest( + f"Either torch.fx is not available, or the model type {config.model_type} is not compatible with torch.fx" + ) configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.return_dict = False @@ -1060,6 +1062,26 @@ class ModelTesterMixin: if end_positions is not None: input_names.append("end_positions") + if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE: + input_names.append("past_key_values") + + # Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs. + if "past_key_values" not in inputs: + batch_size = inputs[next(iter(inputs))].shape[0] + num_heads = model.config.num_attention_heads + head_dim = model.config.hidden_size // model.config.num_attention_heads + + cache_shape = (batch_size, num_heads, 0, head_dim) + pkv = tuple( + ( + torch.rand(cache_shape, dtype=torch.float, device=torch_device), + torch.rand(cache_shape, dtype=torch.float, device=torch_device), + ) + for i in range(model.config.num_hidden_layers) + ) + + inputs["past_key_values"] = pkv + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) @@ -1069,8 +1091,10 @@ class ModelTesterMixin: model.config.problem_type = "single_label_classification" traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - model_output = model(**filtered_inputs) + + with torch.no_grad(): + traced_output = traced_model(**filtered_inputs) + model_output = model(**filtered_inputs) except Exception as e: self.fail(f"Couldn't trace module: {e}")