The extended trainer tests should require torch (#12650)
This commit is contained in:
@@ -28,6 +28,7 @@ from transformers.testing_utils import (
|
|||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
get_torch_dist_unique_port,
|
get_torch_dist_unique_port,
|
||||||
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_torch_non_multi_gpu,
|
require_torch_non_multi_gpu,
|
||||||
@@ -69,6 +70,7 @@ def require_apex(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
class TestTrainerExt(TestCasePlus):
|
class TestTrainerExt(TestCasePlus):
|
||||||
def run_seq2seq_quick(
|
def run_seq2seq_quick(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user