Fix test_compile_static_cache (#30991)
* fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
@@ -658,12 +659,16 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_compile_static_cache(self):
|
||||
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
||||
self.skipTest("This test requires torch >= 2.3 to run.")
|
||||
|
||||
if self.cuda_compute_capability_major_version == 7:
|
||||
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
|
||||
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
8: [
|
||||
|
||||
Reference in New Issue
Block a user