From df848acc5d0ff267c6c9d1c3cfee0536871600d3 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:16:28 +0200 Subject: [PATCH] Fix `test_compile_static_cache` (#30991) * fix * fix * fix * fix --------- Co-authored-by: ydshieh --- tests/models/llama/test_modeling_llama.py | 7 ++----- tests/models/mistral/test_modeling_mistral.py | 5 +++++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 61b33b3ec9..3e84552ab7 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -729,11 +729,8 @@ class LlamaIntegrationTest(unittest.TestCase): "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", ], 7: [ - "Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory " - "goes that nothing travels faster than light, but the faster you go, the slower everything else will " - "be.\nThe theory of relativity", - "My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, " - "and even on a good old fashioned cheeseburger. I love it on everything. I love it so", + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", ], 9: [ "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial" diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index cead890586..9d3570bd43 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -27,6 +27,7 @@ from transformers.testing_utils import ( is_flaky, require_bitsandbytes, require_flash_attn, + require_read_token, require_torch, require_torch_gpu, require_torch_sdpa, @@ -658,12 +659,16 @@ class MistralIntegrationTest(unittest.TestCase): gc.collect() @slow + @require_read_token def test_compile_static_cache(self): # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 # work as intended. See https://github.com/pytorch/pytorch/issues/121943 if version.parse(torch.__version__) < version.parse("2.3.0"): self.skipTest("This test requires torch >= 2.3 to run.") + if self.cuda_compute_capability_major_version == 7: + self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.") + NUM_TOKENS_TO_GENERATE = 40 EXPECTED_TEXT_COMPLETION = { 8: [