start using training_args.parallel_mode (#8882)

This commit is contained in:
Stas Bekman
2020-12-01 11:40:36 -08:00
committed by GitHub
parent b08843cf4d
commit 379005c9d2
2 changed files with 4 additions and 2 deletions

View File

@@ -18,6 +18,7 @@ from transformers.optimization import (
get_polynomial_decay_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
logger = logging.get_logger(__name__)
@@ -123,7 +124,7 @@ class Seq2SeqTrainer(Trainer):
if self.args.sortish_sampler:
self.train_dataset.make_sortish_sampler(
self.args.per_device_train_batch_size,
distributed=(self.args.local_rank != -1),
distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
)
return (