Use Filelock to ensure distributed barriers
see context in https://github.com/huggingface/transformers/pull/4223
This commit is contained in:
@@ -4,10 +4,10 @@ import pickle
|
||||
import time
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...trainer import torch_distributed_zero_first
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,7 +20,7 @@ class TextDataset(Dataset):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False, local_rank=-1,
|
||||
self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False,
|
||||
):
|
||||
assert os.path.isfile(file_path)
|
||||
|
||||
@@ -31,9 +31,10 @@ class TextDataset(Dataset):
|
||||
directory, "cached_lm_{}_{}_{}".format(tokenizer.__class__.__name__, str(block_size), filename,),
|
||||
)
|
||||
|
||||
with torch_distributed_zero_first(local_rank):
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
start = time.time()
|
||||
@@ -80,7 +81,7 @@ class LineByLineTextDataset(Dataset):
|
||||
soon.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1):
|
||||
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
|
||||
assert os.path.isfile(file_path)
|
||||
# Here, we do not cache the features, operating under the assumption
|
||||
# that we will soon use fast multithreaded tokenizers from the
|
||||
|
||||
Reference in New Issue
Block a user