start using training_args.parallel_mode (#8882)
This commit is contained in:
@@ -11,6 +11,7 @@ from seq2seq_trainer import Seq2SeqTrainer
|
|||||||
from seq2seq_training_args import Seq2SeqTrainingArguments
|
from seq2seq_training_args import Seq2SeqTrainingArguments
|
||||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
|
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
|
||||||
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
||||||
|
from transformers.training_args import ParallelMode
|
||||||
from utils import (
|
from utils import (
|
||||||
Seq2SeqDataCollator,
|
Seq2SeqDataCollator,
|
||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
@@ -132,7 +133,7 @@ def main():
|
|||||||
training_args.local_rank,
|
training_args.local_rank,
|
||||||
training_args.device,
|
training_args.device,
|
||||||
training_args.n_gpu,
|
training_args.n_gpu,
|
||||||
bool(training_args.local_rank != -1),
|
bool(training_args.parallel_mode == ParallelMode.DISTRIBUTED),
|
||||||
training_args.fp16,
|
training_args.fp16,
|
||||||
)
|
)
|
||||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from transformers.optimization import (
|
|||||||
get_polynomial_decay_schedule_with_warmup,
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from transformers.trainer_pt_utils import get_tpu_sampler
|
from transformers.trainer_pt_utils import get_tpu_sampler
|
||||||
|
from transformers.training_args import ParallelMode
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -123,7 +124,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
if self.args.sortish_sampler:
|
if self.args.sortish_sampler:
|
||||||
self.train_dataset.make_sortish_sampler(
|
self.train_dataset.make_sortish_sampler(
|
||||||
self.args.per_device_train_batch_size,
|
self.args.per_device_train_batch_size,
|
||||||
distributed=(self.args.local_rank != -1),
|
distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
Reference in New Issue
Block a user