From 17cd7a9d28e12ed3f1623d1193f0b3a2ad4aca92 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 5 Apr 2024 15:14:09 +0200 Subject: [PATCH] Fix `torch.fx` symbolic tracing for LLama (#30047) * [WIP] fix fx * [WIP] fix fx * [WIP] fix fx * [WIP] fix fx * [WIP] fix fx * Apply changes to other models --- src/transformers/models/cohere/modeling_cohere.py | 8 +++++--- src/transformers/models/gemma/modeling_gemma.py | 8 +++++--- src/transformers/models/llama/modeling_llama.py | 8 +++++--- src/transformers/utils/fx.py | 9 ++++++--- tests/models/cohere/test_modeling_cohere.py | 4 +--- tests/models/llama/test_modeling_llama.py | 4 +--- 6 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41bae6db65..95a7d76827 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -908,7 +908,9 @@ class CohereModel(CoherePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] + ) # embed positions hidden_states = inputs_embeds @@ -976,7 +978,7 @@ class CohereModel(CoherePreTrainedModel): # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -989,7 +991,7 @@ class CohereModel(CoherePreTrainedModel): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 2d93c43425..c8b9b11c55 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -888,7 +888,9 @@ class GemmaModel(GemmaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] + ) # embed positions hidden_states = inputs_embeds @@ -962,7 +964,7 @@ class GemmaModel(GemmaPreTrainedModel): # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -975,7 +977,7 @@ class GemmaModel(GemmaPreTrainedModel): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8d0baf63c7..e1afb61be0 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -987,7 +987,9 @@ class LlamaModel(LlamaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] + ) # embed positions hidden_states = inputs_embeds @@ -1055,7 +1057,7 @@ class LlamaModel(LlamaPreTrainedModel): # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -1068,7 +1070,7 @@ class LlamaModel(LlamaPreTrainedModel): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index fd2b1512b2..df0aba8d5d 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -260,11 +260,14 @@ def torch_arange(*args, **kwargs): def torch_full(*args, **kwargs): args = list(args) - if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"): - args[1] = 1 # Any value. + # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device. + if len(args) > 1: + args[1] = 1 + else: + kwargs["fill_value"] = 1 kwargs_without_device = dict(kwargs) kwargs_without_device.pop("device", None) - return torch.full(*args, **kwargs_without_device) + return torch.full(*args, **kwargs_without_device, device="meta") def torch_cat(tensors, dim=None, axis=None, *, out=None): diff --git a/tests/models/cohere/test_modeling_cohere.py b/tests/models/cohere/test_modeling_cohere.py index 883eb92e8b..3e86ffe9d9 100644 --- a/tests/models/cohere/test_modeling_cohere.py +++ b/tests/models/cohere/test_modeling_cohere.py @@ -283,9 +283,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ) test_headmasking = False test_pruning = False - fx_compatible = ( - False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753 - ) + fx_compatible = True # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index e0a3990bd8..0fb4087dba 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -305,9 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) test_headmasking = False test_pruning = False - fx_compatible = ( - False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753 - ) + fx_compatible = True # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer