From 64e6098094d063687f90d3bf49bdc7571551c344 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 25 Jun 2021 14:58:03 -0700 Subject: [PATCH] [trainer] add main_process_first context manager (#12351) * main_process_first context manager * handle multi-node, add context description * sync desc --- .../pytorch/translation/run_translation.py | 51 ++++++++++--------- src/transformers/training_args.py | 44 ++++++++++++++++ 2 files changed, 71 insertions(+), 24 deletions(-) diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index 680ab4fd50..b41386f0fe 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -428,14 +428,15 @@ def main(): train_dataset = raw_datasets["train"] if data_args.max_train_samples is not None: train_dataset = train_dataset.select(range(data_args.max_train_samples)) - train_dataset = train_dataset.map( - preprocess_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on train dataset", - ) + with training_args.main_process_first(desc="train dataset map pre-processing"): + train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) if training_args.do_eval: max_target_length = data_args.val_max_target_length @@ -444,14 +445,15 @@ def main(): eval_dataset = raw_datasets["validation"] if data_args.max_eval_samples is not None: eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) - eval_dataset = eval_dataset.map( - preprocess_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on validation dataset", - ) + with training_args.main_process_first(desc="validation dataset map pre-processing"): + eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) if training_args.do_predict: max_target_length = data_args.val_max_target_length @@ -460,14 +462,15 @@ def main(): predict_dataset = raw_datasets["test"] if data_args.max_predict_samples is not None: predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) - predict_dataset = predict_dataset.map( - preprocess_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on prediction dataset", - ) + with training_args.main_process_first(desc="prediction dataset map pre-processing"): + predict_dataset = predict_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on prediction dataset", + ) # Data collator label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a2bd83a4b1..024fac6ec8 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import json import os import warnings @@ -968,6 +969,49 @@ class TrainingArguments: """ return not (self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled()) + @contextlib.contextmanager + def main_process_first(self, local=True, desc="work"): + """ + A context manager for torch distributed environment where on needs to do something on the main process, + while blocking replicas, and when it's finished releasing the replicas. + + One such use is for ``datasets``'s ``map`` feature which to be efficient should be run once on the main + process, which upon completion saves a cached version of results and which then automatically gets loaded + by the replicas. + + Args: + local (:obj:`bool`, `optional`, defaults to :obj:`True`): + if :obj:`True` first means process of rank 0 of each node if :obj:`False` first means process of rank 0 + of node rank 0 In multi-node environment with a shared filesystem you most likely will want to use + ``local=False`` so that only the main process of the first node will do the processing. If however, the + filesystem is not shared, then the main process of each node will need to do the processing, which is + the default behavior. + desc (:obj:`str`, `optional`, defaults to ``"work"``): + a work description to be used in debug logs + + """ + if is_torch_available() and self.world_size > 1: + if local: + is_main_process = self.local_process_index == 0 + main_process_desc = "main local process" + else: + is_main_process = self.process_index == 0 + main_process_desc = "main process" + + try: + if not is_main_process: + # tell all replicas to wait + logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}") + torch.distributed.barrier() + yield + finally: + if is_main_process: + # the wait is over + logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas") + torch.distributed.barrier() + else: + yield + def to_dict(self): """ Serializes this instance while replace `Enum` by their values (for JSON serialization support).