[Utils] torch version checks optionally accept dev versions (#36847)

This commit is contained in:
Joao Gante
2025-03-25 10:58:58 +00:00
committed by GitHub
parent 80b4c5dcc9
commit bc1c90a755
4 changed files with 35 additions and 22 deletions

View File

@@ -605,7 +605,7 @@ class CacheIntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"
@@ -633,7 +633,7 @@ class CacheIntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"