From 379005c9d202ba9e38a759728bcfc7806759a0bb Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 1 Dec 2020 11:40:36 -0800 Subject: [PATCH] start using training_args.parallel_mode (#8882) --- examples/seq2seq/finetune_trainer.py | 3 ++- examples/seq2seq/seq2seq_trainer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 1bb3471139..473de92734 100755 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -11,6 +11,7 @@ from seq2seq_trainer import Seq2SeqTrainer from seq2seq_training_args import Seq2SeqTrainingArguments from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed from transformers.trainer_utils import EvaluationStrategy, is_main_process +from transformers.training_args import ParallelMode from utils import ( Seq2SeqDataCollator, Seq2SeqDataset, @@ -132,7 +133,7 @@ def main(): training_args.local_rank, training_args.device, training_args.n_gpu, - bool(training_args.local_rank != -1), + bool(training_args.parallel_mode == ParallelMode.DISTRIBUTED), training_args.fp16, ) # Set the verbosity to info of the Transformers logger (on main process only): diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 4d102990e4..684407a6bc 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -18,6 +18,7 @@ from transformers.optimization import ( get_polynomial_decay_schedule_with_warmup, ) from transformers.trainer_pt_utils import get_tpu_sampler +from transformers.training_args import ParallelMode logger = logging.get_logger(__name__) @@ -123,7 +124,7 @@ class Seq2SeqTrainer(Trainer): if self.args.sortish_sampler: self.train_dataset.make_sortish_sampler( self.args.per_device_train_batch_size, - distributed=(self.args.local_rank != -1), + distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED), ) return (