remove SharedDDP as it is deprecated (#25702)
* remove SharedDDP as it was deprecated * apply review suggestion * make style * Oops, forgot to remove the compute_loss context manager in Seq2SeqTrainer. * remove the unnecessary conditional statement * keep the logic of IPEX * clean code * mixed precision setup & make fixup --------- Co-authored-by: statelesshz <jihuazhong1@huawei.com>
This commit is contained in:
@@ -16,7 +16,6 @@ import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
@@ -32,7 +31,6 @@ from transformers.testing_utils import (
|
||||
get_torch_dist_unique_port,
|
||||
require_apex,
|
||||
require_bitsandbytes,
|
||||
require_fairscale,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
@@ -105,36 +103,6 @@ class TestTrainerExt(TestCasePlus):
|
||||
def test_run_seq2seq_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True)
|
||||
|
||||
# test --sharded_ddp w/o --fp16
|
||||
@unittest.skip("Requires an update of the env running those tests")
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_sharded_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
|
||||
|
||||
# test --sharded_ddp w/ --fp16
|
||||
@unittest.skip("Requires an update of the env running those tests")
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
|
||||
|
||||
# test --sharded_ddp zero_dp_2 w/o --fp16
|
||||
@unittest.skip("Requires an update of the env running those tests")
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_fully_sharded_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
|
||||
|
||||
# test --sharded_ddp zero_dp_2 w/ --fp16
|
||||
@unittest.skip("Requires an update of the env running those tests")
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(
|
||||
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
|
||||
)
|
||||
|
||||
@require_apex
|
||||
@require_torch_gpu
|
||||
def test_run_seq2seq_apex(self):
|
||||
|
||||
Reference in New Issue
Block a user