Make Gemma work with torch.compile (#30775)

* fix

* [run-slow] gemma

* add test

* add `test_compile_static_cache`

* fix

* style

* remove subprocess

* use attribute

* fix

* style

* update

* [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-05-16 13:41:33 +02:00
committed by GitHub
parent 0753134f4d
commit 1b3dba9417
8 changed files with 110 additions and 26 deletions

View File

@@ -312,6 +312,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
def setUp(self):
self.model_tester = LlamaModelTester(self)
self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)