enhance require_deterministic_for_xpu (#37437)
* enhance require_deterministic_for_xpu Signed-off-by: YAO Matrix <matrix.yao@intel.com> * fix style Signed-off-by: YAO Matrix <matrix.yao@intel.com> * fix style Signed-off-by: YAO Matrix <matrix.yao@intel.com> --------- Signed-off-by: YAO Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -139,7 +139,6 @@ from .utils import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_bf16_available_on_device,
|
is_torch_bf16_available_on_device,
|
||||||
is_torch_bf16_gpu_available,
|
is_torch_bf16_gpu_available,
|
||||||
is_torch_deterministic,
|
|
||||||
is_torch_fp16_available_on_device,
|
is_torch_fp16_available_on_device,
|
||||||
is_torch_greater_or_equal,
|
is_torch_greater_or_equal,
|
||||||
is_torch_hpu_available,
|
is_torch_hpu_available,
|
||||||
@@ -1073,12 +1072,19 @@ def require_torch_bf16_gpu(test_case):
|
|||||||
|
|
||||||
|
|
||||||
def require_deterministic_for_xpu(test_case):
|
def require_deterministic_for_xpu(test_case):
|
||||||
if is_torch_xpu_available():
|
@wraps(test_case)
|
||||||
return unittest.skipUnless(is_torch_deterministic(), "test requires torch to use deterministic algorithms")(
|
def wrapper(*args, **kwargs):
|
||||||
test_case
|
if is_torch_xpu_available():
|
||||||
)
|
original_state = torch.are_deterministic_algorithms_enabled()
|
||||||
else:
|
try:
|
||||||
return test_case
|
torch.use_deterministic_algorithms(True)
|
||||||
|
return test_case(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
torch.use_deterministic_algorithms(original_state)
|
||||||
|
else:
|
||||||
|
return test_case(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_torch_tf32(test_case):
|
def require_torch_tf32(test_case):
|
||||||
|
|||||||
@@ -936,6 +936,7 @@ class BertGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCa
|
|||||||
}
|
}
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
@require_deterministic_for_xpu
|
||||||
def test_roberta2roberta_summarization(self):
|
def test_roberta2roberta_summarization(self):
|
||||||
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")
|
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -1080,6 +1081,7 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
@require_deterministic_for_xpu
|
||||||
def test_bert2gpt2_summarization(self):
|
def test_bert2gpt2_summarization(self):
|
||||||
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
|
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
|
||||||
|
|
||||||
|
|||||||
@@ -634,6 +634,7 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@require_deterministic_for_xpu
|
||||||
@unittest.skip(reason="Cannot save full model as Speech2TextModel != Speech2TextEncoder")
|
@unittest.skip(reason="Cannot save full model as Speech2TextModel != Speech2TextEncoder")
|
||||||
def test_save_and_load_from_pretrained(self):
|
def test_save_and_load_from_pretrained(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user