Use Filelock to ensure distributed barriers
see context in https://github.com/huggingface/transformers/pull/4223
This commit is contained in:
@@ -171,7 +171,6 @@ def main():
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.train,
|
||||
local_rank=training_args.local_rank,
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
@@ -185,7 +184,6 @@ def main():
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.dev,
|
||||
local_rank=training_args.local_rank,
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
@@ -261,7 +259,6 @@ def main():
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.test,
|
||||
local_rank=training_args.local_rank,
|
||||
)
|
||||
|
||||
predictions, label_ids, metrics = trainer.predict(test_dataset)
|
||||
|
||||
@@ -22,6 +22,8 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
@@ -68,7 +70,6 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from transformers import torch_distributed_zero_first
|
||||
|
||||
class NerDataset(Dataset):
|
||||
"""
|
||||
@@ -90,16 +91,16 @@ if is_torch_available():
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
local_rank=-1,
|
||||
):
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
|
||||
)
|
||||
|
||||
with torch_distributed_zero_first(local_rank):
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
@@ -125,9 +126,8 @@ if is_torch_available():
|
||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||
pad_token_label_id=self.pad_token_label_id,
|
||||
)
|
||||
if local_rank in [-1, 0]:
|
||||
logger.info(f"Saving features into cached file {cached_features_file}")
|
||||
torch.save(self.features, cached_features_file)
|
||||
logger.info(f"Saving features into cached file {cached_features_file}")
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
Reference in New Issue
Block a user