Introduce PartialState as the device handler in the Trainer (#22752)

* Use accelerate for device management

* Add accelerate to setup


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Zachary Mueller
2023-04-17 15:09:45 -04:00
committed by GitHub
parent 50caa20628
commit 03462875cc
4 changed files with 56 additions and 140 deletions

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Dict
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
@@ -23,6 +22,7 @@ from transformers.testing_utils import (
require_torch_multi_gpu,
require_torch_neuroncore,
)
from transformers.training_args import ParallelMode
from transformers.utils import logging
@@ -66,15 +66,13 @@ if is_torch_available():
class TestTrainerDistributedNeuronCore(TestCasePlus):
@require_torch_neuroncore
def test_trainer(self):
distributed_args = f"""
-m torch.distributed.run
--nproc_per_node=2
distributed_args = f"""--nproc_per_node=2
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_trainer_distributed.py
""".split()
output_dir = self.get_auto_remove_tmp_dir()
args = f"--output_dir {output_dir}".split()
cmd = [sys.executable] + distributed_args + args
cmd = ["torchrun"] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
@@ -82,15 +80,13 @@ class TestTrainerDistributedNeuronCore(TestCasePlus):
class TestTrainerDistributed(TestCasePlus):
@require_torch_multi_gpu
def test_trainer(self):
distributed_args = f"""
-m torch.distributed.run
--nproc_per_node={torch.cuda.device_count()}
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_trainer_distributed.py
""".split()
output_dir = self.get_auto_remove_tmp_dir()
args = f"--output_dir {output_dir}".split()
cmd = [sys.executable] + distributed_args + args
cmd = ["torchrun"] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
@@ -105,7 +101,7 @@ if __name__ == "__main__":
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
f"distributed training: {training_args.local_rank != -1}"
f"distributed training: {training_args.parallel_mode != ParallelMode.NOT_DISTRIBUTED}"
)
# Essentially, what we want to verify in the distributed case is that we get all samples back,