[trainer] add main_process_first context manager (#12351)
* main_process_first context manager * handle multi-node, add context description * sync desc
This commit is contained in:
@@ -428,14 +428,15 @@ def main():
|
|||||||
train_dataset = raw_datasets["train"]
|
train_dataset = raw_datasets["train"]
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||||
train_dataset = train_dataset.map(
|
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||||
preprocess_function,
|
train_dataset = train_dataset.map(
|
||||||
batched=True,
|
preprocess_function,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
batched=True,
|
||||||
remove_columns=column_names,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
remove_columns=column_names,
|
||||||
desc="Running tokenizer on train dataset",
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
desc="Running tokenizer on train dataset",
|
||||||
|
)
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
@@ -444,14 +445,15 @@ def main():
|
|||||||
eval_dataset = raw_datasets["validation"]
|
eval_dataset = raw_datasets["validation"]
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
with training_args.main_process_first(desc="validation dataset map pre-processing"):
|
||||||
preprocess_function,
|
eval_dataset = eval_dataset.map(
|
||||||
batched=True,
|
preprocess_function,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
batched=True,
|
||||||
remove_columns=column_names,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
remove_columns=column_names,
|
||||||
desc="Running tokenizer on validation dataset",
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
desc="Running tokenizer on validation dataset",
|
||||||
|
)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
@@ -460,14 +462,15 @@ def main():
|
|||||||
predict_dataset = raw_datasets["test"]
|
predict_dataset = raw_datasets["test"]
|
||||||
if data_args.max_predict_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
predict_dataset = predict_dataset.map(
|
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
|
||||||
preprocess_function,
|
predict_dataset = predict_dataset.map(
|
||||||
batched=True,
|
preprocess_function,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
batched=True,
|
||||||
remove_columns=column_names,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
remove_columns=column_names,
|
||||||
desc="Running tokenizer on prediction dataset",
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
desc="Running tokenizer on prediction dataset",
|
||||||
|
)
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -968,6 +969,49 @@ class TrainingArguments:
|
|||||||
"""
|
"""
|
||||||
return not (self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled())
|
return not (self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled())
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def main_process_first(self, local=True, desc="work"):
|
||||||
|
"""
|
||||||
|
A context manager for torch distributed environment where on needs to do something on the main process,
|
||||||
|
while blocking replicas, and when it's finished releasing the replicas.
|
||||||
|
|
||||||
|
One such use is for ``datasets``'s ``map`` feature which to be efficient should be run once on the main
|
||||||
|
process, which upon completion saves a cached version of results and which then automatically gets loaded
|
||||||
|
by the replicas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
if :obj:`True` first means process of rank 0 of each node if :obj:`False` first means process of rank 0
|
||||||
|
of node rank 0 In multi-node environment with a shared filesystem you most likely will want to use
|
||||||
|
``local=False`` so that only the main process of the first node will do the processing. If however, the
|
||||||
|
filesystem is not shared, then the main process of each node will need to do the processing, which is
|
||||||
|
the default behavior.
|
||||||
|
desc (:obj:`str`, `optional`, defaults to ``"work"``):
|
||||||
|
a work description to be used in debug logs
|
||||||
|
|
||||||
|
"""
|
||||||
|
if is_torch_available() and self.world_size > 1:
|
||||||
|
if local:
|
||||||
|
is_main_process = self.local_process_index == 0
|
||||||
|
main_process_desc = "main local process"
|
||||||
|
else:
|
||||||
|
is_main_process = self.process_index == 0
|
||||||
|
main_process_desc = "main process"
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_main_process:
|
||||||
|
# tell all replicas to wait
|
||||||
|
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
|
||||||
|
torch.distributed.barrier()
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if is_main_process:
|
||||||
|
# the wait is over
|
||||||
|
logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
|
||||||
|
torch.distributed.barrier()
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
|
||||||
|
|||||||
Reference in New Issue
Block a user