[Utils] torch version checks optionally accept dev versions (#36847)
This commit is contained in:
@@ -605,7 +605,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
||||
device = model.device
|
||||
|
||||
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
|
||||
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
|
||||
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
|
||||
|
||||
input_text = "Fun fact:"
|
||||
@@ -633,7 +633,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
||||
device = model.device
|
||||
|
||||
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
|
||||
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
|
||||
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
|
||||
|
||||
input_text = "Fun fact:"
|
||||
|
||||
Reference in New Issue
Block a user