diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 49ea8ef177..48d4cd30ed 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -139,7 +139,6 @@ from .utils import ( is_torch_available, is_torch_bf16_available_on_device, is_torch_bf16_gpu_available, - is_torch_deterministic, is_torch_fp16_available_on_device, is_torch_greater_or_equal, is_torch_hpu_available, @@ -1073,12 +1072,19 @@ def require_torch_bf16_gpu(test_case): def require_deterministic_for_xpu(test_case): - if is_torch_xpu_available(): - return unittest.skipUnless(is_torch_deterministic(), "test requires torch to use deterministic algorithms")( - test_case - ) - else: - return test_case + @wraps(test_case) + def wrapper(*args, **kwargs): + if is_torch_xpu_available(): + original_state = torch.are_deterministic_algorithms_enabled() + try: + torch.use_deterministic_algorithms(True) + return test_case(*args, **kwargs) + finally: + torch.use_deterministic_algorithms(original_state) + else: + return test_case(*args, **kwargs) + + return wrapper def require_torch_tf32(test_case): diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 6d2286bae3..a5f7afb16f 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -936,6 +936,7 @@ class BertGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCa } @slow + @require_deterministic_for_xpu def test_roberta2roberta_summarization(self): model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc") model.to(torch_device) @@ -1080,6 +1081,7 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): pass @slow + @require_deterministic_for_xpu def test_bert2gpt2_summarization(self): model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16") diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index b7c3b97e5c..7ac8d5631d 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -634,6 +634,7 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase): def test_encoder_decoder_model_from_pretrained_configs(self): pass + @require_deterministic_for_xpu @unittest.skip(reason="Cannot save full model as Speech2TextModel != Speech2TextEncoder") def test_save_and_load_from_pretrained(self): pass