symbolic_trace: add past_key_values, llama, sdpa support (#28447)

* torch.fx: add pkv, llama, sdpa support

* Update src/transformers/models/opt/modeling_opt.py

* remove spaces

* trigger ci

* use explicit variable names
This commit is contained in:
fxmarty
2024-01-17 11:50:53 +01:00
committed by GitHub
parent 09eb11a1bd
commit a6adc05e6b
4 changed files with 68 additions and 11 deletions

View File

@@ -292,6 +292,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
fx_compatible = True
def setUp(self):
self.model_tester = LlamaModelTester(self)