Make Gemma work with torch.compile (#30775)
* fix * [run-slow] gemma * add test * add `test_compile_static_cache` * fix * style * remove subprocess * use attribute * fix * style * update * [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -312,6 +312,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LlamaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
|
||||
|
||||
Reference in New Issue
Block a user