Fix torch.fx symbolic tracing for Llama (#30047)
* [WIP] fix fx
* [WIP] fix fx
* [WIP] fix fx
* [WIP] fix fx
* [WIP] fix fx
* Apply changes to other models
This commit is contained in:
@@ -908,7 +908,9 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -976,7 +978,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
return attention_mask
|
return attention_mask
|
||||||
@@ -989,7 +991,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
target_length = self.config.max_position_embeddings
|
target_length = self.config.max_position_embeddings
|
||||||
else: # dynamic cache
|
else: # dynamic cache
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
|||||||
@@ -888,7 +888,9 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -962,7 +964,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
return attention_mask
|
return attention_mask
|
||||||
@@ -975,7 +977,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
target_length = self.config.max_position_embeddings
|
target_length = self.config.max_position_embeddings
|
||||||
else: # dynamic cache
|
else: # dynamic cache
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
|||||||
@@ -987,7 +987,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -1055,7 +1057,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
return attention_mask
|
return attention_mask
|
||||||
@@ -1068,7 +1070,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
target_length = self.config.max_position_embeddings
|
target_length = self.config.max_position_embeddings
|
||||||
else: # dynamic cache
|
else: # dynamic cache
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
|||||||
@@ -260,11 +260,14 @@ def torch_arange(*args, **kwargs):
|
|||||||
|
|
||||||
def torch_full(*args, **kwargs):
|
def torch_full(*args, **kwargs):
|
||||||
args = list(args)
|
args = list(args)
|
||||||
if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
|
# We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
|
||||||
args[1] = 1 # Any value.
|
if len(args) > 1:
|
||||||
|
args[1] = 1
|
||||||
|
else:
|
||||||
|
kwargs["fill_value"] = 1
|
||||||
kwargs_without_device = dict(kwargs)
|
kwargs_without_device = dict(kwargs)
|
||||||
kwargs_without_device.pop("device", None)
|
kwargs_without_device.pop("device", None)
|
||||||
return torch.full(*args, **kwargs_without_device)
|
return torch.full(*args, **kwargs_without_device, device="meta")
|
||||||
|
|
||||||
|
|
||||||
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
||||||
|
|||||||
@@ -283,9 +283,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
)
|
)
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
fx_compatible = (
|
fx_compatible = True
|
||||||
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
|
|
||||||
)
|
|
||||||
|
|
||||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||||
# This is because we are hitting edge cases with the causal_mask buffer
|
# This is because we are hitting edge cases with the causal_mask buffer
|
||||||
|
|||||||
@@ -305,9 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
)
|
)
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
fx_compatible = (
|
fx_compatible = True
|
||||||
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
|
|
||||||
)
|
|
||||||
|
|
||||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||||
# This is because we are hitting edge cases with the causal_mask buffer
|
# This is because we are hitting edge cases with the causal_mask buffer
|
||||||
|
|||||||
Reference in New Issue
Block a user