diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py index a0a328cf09..dca3604e13 100644 --- a/tests/extended/test_trainer_ext.py +++ b/tests/extended/test_trainer_ext.py @@ -28,6 +28,7 @@ from transformers.testing_utils import ( execute_subprocess_async, get_gpu_count, get_torch_dist_unique_port, + require_torch, require_torch_gpu, require_torch_multi_gpu, require_torch_non_multi_gpu, @@ -69,6 +70,7 @@ def require_apex(test_case): return test_case +@require_torch class TestTrainerExt(TestCasePlus): def run_seq2seq_quick( self,