[TPU] Doc, fix xla_spawn.py, only preprocess dataset once (#4223)

* [TPU] Doc, fix xla_spawn.py, only preprocess dataset once

* Update examples/README.md

* [xla_spawn] Add `_mp_fn` to other Trainer scripts

* [TPU] Fix: eval dataloader was None
This commit is contained in:
Julien Chaumond
2020-05-08 14:10:05 -04:00
committed by GitHub
parent 274d850d34
commit 7b75aa9fa5
10 changed files with 88 additions and 47 deletions

View File

@@ -5,12 +5,12 @@ from dataclasses import dataclass, field
from typing import List, Optional
import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
from ...trainer import torch_distributed_zero_first
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures
@@ -63,7 +63,6 @@ class GlueDataset(Dataset):
tokenizer: PreTrainedTokenizer,
limit_length: Optional[int] = None,
evaluate=False,
local_rank=-1,
):
self.args = args
processor = glue_processors[args.task_name]()
@@ -75,9 +74,11 @@ class GlueDataset(Dataset):
"dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
),
)
with torch_distributed_zero_first(local_rank):
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
@@ -109,13 +110,12 @@ class GlueDataset(Dataset):
label_list=label_list,
output_mode=self.output_mode,
)
if local_rank in [-1, 0]:
start = time.time()
torch.save(self.features, cached_features_file)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info(
f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)
start = time.time()
torch.save(self.features, cached_features_file)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info(
f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)
def __len__(self):
return len(self.features)

View File

@@ -6,7 +6,7 @@ import re
import shutil
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -195,10 +195,12 @@ class Trainer:
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None
data_loader = DataLoader(
eval_dataset if eval_dataset is not None else self.eval_dataset,
eval_dataset,
sampler=sampler,
batch_size=self.args.eval_batch_size,
shuffle=False,
@@ -267,6 +269,16 @@ class Trainer:
# keep track of model topology and gradients
wandb.watch(self.model)
def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
    """
    Return the number of examples in the Dataset behind `dataloader`.

    When a TPU is available, the loader is asserted to be a `pl.PerDeviceLoader`
    and the dataset is reached through its nested `_loader._loader` attributes.
    Otherwise the dataset is read directly from `dataloader.dataset`.
    """
    if not is_tpu_available():
        return len(dataloader.dataset)

    # TPU path: the dataset sits two `_loader` levels below the per-device loader.
    assert isinstance(dataloader, pl.PerDeviceLoader)
    wrapped_loader = dataloader._loader._loader
    return len(wrapped_loader.dataset)
def train(self, model_path: Optional[str] = None):
"""
Main training entry point.
@@ -326,17 +338,15 @@ class Trainer:
# Train!
if is_tpu_available():
num_examples = len(train_dataloader._loader._loader.dataset)
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
else:
num_examples = len(train_dataloader.dataset)
total_train_batch_size = (
self.args.train_batch_size
* self.args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", num_examples)
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
logger.info(" Num Epochs = %d", num_train_epochs)
logger.info(" Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
@@ -606,9 +616,13 @@ class Trainer:
model = self.model
model.to(self.args.device)
if is_tpu_available():
batch_size = dataloader._loader._loader.batch_size
else:
batch_size = dataloader.batch_size
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", len(dataloader.dataset))
logger.info(" Batch size = %d", dataloader.batch_size)
logger.info(" Num examples = %d", self.num_examples(dataloader))
logger.info(" Batch size = %d", batch_size)
eval_losses: List[float] = []
preds: np.ndarray = None
label_ids: np.ndarray = None