[trainer] add main_process_first context manager (#12351)
* main_process_first context manager * handle multi-node, add context description * sync desc
This commit is contained in:
@@ -428,6 +428,7 @@ def main():
|
||||
train_dataset = raw_datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
@@ -444,6 +445,7 @@ def main():
|
||||
eval_dataset = raw_datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
with training_args.main_process_first(desc="validation dataset map pre-processing"):
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
@@ -460,6 +462,7 @@ def main():
|
||||
predict_dataset = raw_datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
|
||||
predict_dataset = predict_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
@@ -968,6 +969,49 @@ class TrainingArguments:
|
||||
"""
|
||||
return not (self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled())
|
||||
|
||||
@contextlib.contextmanager
|
||||
def main_process_first(self, local=True, desc="work"):
|
||||
"""
|
||||
A context manager for torch distributed environment where on needs to do something on the main process,
|
||||
while blocking replicas, and when it's finished releasing the replicas.
|
||||
|
||||
One such use is for ``datasets``'s ``map`` feature which to be efficient should be run once on the main
|
||||
process, which upon completion saves a cached version of results and which then automatically gets loaded
|
||||
by the replicas.
|
||||
|
||||
Args:
|
||||
local (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
if :obj:`True` first means process of rank 0 of each node if :obj:`False` first means process of rank 0
|
||||
of node rank 0 In multi-node environment with a shared filesystem you most likely will want to use
|
||||
``local=False`` so that only the main process of the first node will do the processing. If however, the
|
||||
filesystem is not shared, then the main process of each node will need to do the processing, which is
|
||||
the default behavior.
|
||||
desc (:obj:`str`, `optional`, defaults to ``"work"``):
|
||||
a work description to be used in debug logs
|
||||
|
||||
"""
|
||||
if is_torch_available() and self.world_size > 1:
|
||||
if local:
|
||||
is_main_process = self.local_process_index == 0
|
||||
main_process_desc = "main local process"
|
||||
else:
|
||||
is_main_process = self.process_index == 0
|
||||
main_process_desc = "main process"
|
||||
|
||||
try:
|
||||
if not is_main_process:
|
||||
# tell all replicas to wait
|
||||
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
|
||||
torch.distributed.barrier()
|
||||
yield
|
||||
finally:
|
||||
if is_main_process:
|
||||
# the wait is over
|
||||
logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
|
||||
torch.distributed.barrier()
|
||||
else:
|
||||
yield
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
|
||||
|
||||
Reference in New Issue
Block a user