[TPU] Doc, fix xla_spawn.py, only preprocess dataset once (#4223)
* [TPU] Doc, fix xla_spawn.py, only preprocess dataset once * Update examples/README.md * [xla_spawn] Add `_mp_fn` to other Trainer scripts * [TPU] Fix: eval dataloader was None
This commit is contained in:
@@ -12,17 +12,13 @@ Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/lau
|
||||
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from argparse import REMAINDER, ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
|
||||
|
||||
def trim_suffix(s: str, suffix: str):
|
||||
return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
Helper function parsing the command line options
|
||||
@@ -44,7 +40,7 @@ def parse_args():
|
||||
"training_script",
|
||||
type=str,
|
||||
help=(
|
||||
"The full module name to the single TPU training "
|
||||
"The full path to the single TPU training "
|
||||
"program/script to be launched in parallel, "
|
||||
"followed by all the arguments for the "
|
||||
"training script"
|
||||
@@ -61,7 +57,9 @@ def main():
|
||||
args = parse_args()
|
||||
|
||||
# Import training_script as a module.
|
||||
mod_name = trim_suffix(os.path.basename(args.training_script), ".py")
|
||||
script_fpath = Path(args.training_script)
|
||||
sys.path.append(str(script_fpath.parent.resolve()))
|
||||
mod_name = script_fpath.stem
|
||||
mod = importlib.import_module(mod_name)
|
||||
|
||||
# Patch sys.argv
|
||||
|
||||
Reference in New Issue
Block a user