Fix llama + gemma accelerate tests (#29380)

This commit is contained in:
Marc Sun
2024-03-01 10:32:36 -05:00
committed by GitHub
parent 15f8296a9b
commit cec773345a
2 changed files with 8 additions and 0 deletions

View File

@@ -302,6 +302,10 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# NOTE(review): presumably feature flags read by the inherited tester mixins
# to enable or skip groups of shared tests — confirm against the mixins.
test_pruning = False
fx_compatible = True
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`.
# This is because we are hitting edge cases with the causal_mask buffer.
# NOTE(review): presumably this overrides a default list on a test mixin whose
# last entry is `0.9` — confirm where the default is defined.
model_split_percents = [0.5, 0.7, 0.8]
def setUp(self):
self.model_tester = LlamaModelTester(self)
self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)