[TPU] Doc, fix xla_spawn.py, only preprocess dataset once (#4223)

* [TPU] Doc, fix xla_spawn.py, only preprocess dataset once

* Update examples/README.md

* [xla_spawn] Add `_mp_fn` to other Trainer scripts

* [TPU] Fix: eval dataloader was None
This commit is contained in:
Julien Chaumond
2020-05-08 14:10:05 -04:00
committed by GitHub
parent 274d850d34
commit 7b75aa9fa5
10 changed files with 88 additions and 47 deletions

View File

@@ -5,12 +5,12 @@ from dataclasses import dataclass, field
from typing import List, Optional
import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
from ...trainer import torch_distributed_zero_first
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures
@@ -63,7 +63,6 @@ class GlueDataset(Dataset):
tokenizer: PreTrainedTokenizer,
limit_length: Optional[int] = None,
evaluate=False,
local_rank=-1,
):
self.args = args
processor = glue_processors[args.task_name]()
@@ -75,9 +74,11 @@ class GlueDataset(Dataset):
"dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
),
)
with torch_distributed_zero_first(local_rank):
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
@@ -109,13 +110,12 @@ class GlueDataset(Dataset):
label_list=label_list,
output_mode=self.output_mode,
)
if local_rank in [-1, 0]:
start = time.time()
torch.save(self.features, cached_features_file)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info(
f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)
start = time.time()
torch.save(self.features, cached_features_file)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info(
f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)
def __len__(self):
return len(self.features)