remove SharedDDP as it is deprecated (#25702)

* remove SharedDDP as it was drepracated

* apply review suggestion

* make style

* Oops,forgot to remove the compute_loss context manager in Seq2SeqTrainer.

* remove the unnecessary conditional statement

* keep the logic of IPEX

* clean code

* mix precision setup & make fixup

---------

Co-authored-by: statelesshz <jihuazhong1@huawei.com>
This commit is contained in:
statelesshz
2023-10-06 22:03:11 +08:00
committed by GitHub
parent e840aa67e8
commit 27597fea07
11 changed files with 39 additions and 299 deletions

View File

@@ -16,7 +16,6 @@ import math
import os
import re
import sys
import unittest
from pathlib import Path
from typing import Tuple
from unittest.mock import patch
@@ -32,7 +31,6 @@ from transformers.testing_utils import (
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_fairscale,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
@@ -105,36 +103,6 @@ class TestTrainerExt(TestCasePlus):
def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)
# test --sharded_ddp w/o --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
# test --sharded_ddp w/ --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
# test --sharded_ddp zero_dp_2 w/o --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
# test --sharded_ddp zero_dp_2 w/ --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
)
@require_apex
@require_torch_gpu
def test_run_seq2seq_apex(self):