Fix torch.fx symbolic tracing for LLama (#30047)

* [WIP] fix fx

* [WIP] fix fx

* [WIP] fix fx

* [WIP] fix fx

* [WIP] fix fx

* Apply changes to other models
This commit is contained in:
Michael Benayoun
2024-04-05 15:14:09 +02:00
committed by GitHub
parent 48795317a2
commit 17cd7a9d28
6 changed files with 23 additions and 18 deletions

View File

@@ -283,9 +283,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
)
test_headmasking = False
test_pruning = False
fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
fx_compatible = True
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer

View File

@@ -305,9 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
fx_compatible = True
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer