Use Filelock to ensure distributed barriers

see context in https://github.com/huggingface/transformers/pull/4223
This commit is contained in:
Julien Chaumond
2020-05-14 11:58:32 -04:00
parent 015f7812ed
commit c547f15a17
6 changed files with 25 additions and 33 deletions

View File

@@ -4,10 +4,10 @@ import pickle
import time
import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset
from ...tokenization_utils import PreTrainedTokenizer
from ...trainer import torch_distributed_zero_first
logger = logging.getLogger(__name__)
@@ -20,7 +20,7 @@ class TextDataset(Dataset):
"""
def __init__(
self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False, local_rank=-1,
self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False,
):
assert os.path.isfile(file_path)
@@ -31,9 +31,10 @@ class TextDataset(Dataset):
directory, "cached_lm_{}_{}_{}".format(tokenizer.__class__.__name__, str(block_size), filename,),
)
with torch_distributed_zero_first(local_rank):
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not overwrite_cache:
start = time.time()
@@ -80,7 +81,7 @@ class LineByLineTextDataset(Dataset):
soon.
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1):
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
assert os.path.isfile(file_path)
# Here, we do not cache the features, operating under the assumption
# that we will soon use fast multithreaded tokenizers from the