Use Filelock to ensure distributed barriers
see context in https://github.com/huggingface/transformers/pull/4223
This commit is contained in:
@@ -118,13 +118,9 @@ class DataTrainingArguments:
|
|||||||
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
|
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
|
||||||
file_path = args.eval_data_file if evaluate else args.train_data_file
|
file_path = args.eval_data_file if evaluate else args.train_data_file
|
||||||
if args.line_by_line:
|
if args.line_by_line:
|
||||||
return LineByLineTextDataset(
|
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||||
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, local_rank=local_rank
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return TextDataset(
|
return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||||
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, local_rank=local_rank,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -159,7 +159,6 @@ def main():
|
|||||||
max_seq_length=data_args.max_seq_length,
|
max_seq_length=data_args.max_seq_length,
|
||||||
overwrite_cache=data_args.overwrite_cache,
|
overwrite_cache=data_args.overwrite_cache,
|
||||||
mode=Split.train,
|
mode=Split.train,
|
||||||
local_rank=training_args.local_rank,
|
|
||||||
)
|
)
|
||||||
if training_args.do_train
|
if training_args.do_train
|
||||||
else None
|
else None
|
||||||
@@ -172,7 +171,6 @@ def main():
|
|||||||
max_seq_length=data_args.max_seq_length,
|
max_seq_length=data_args.max_seq_length,
|
||||||
overwrite_cache=data_args.overwrite_cache,
|
overwrite_cache=data_args.overwrite_cache,
|
||||||
mode=Split.dev,
|
mode=Split.dev,
|
||||||
local_rank=training_args.local_rank,
|
|
||||||
)
|
)
|
||||||
if training_args.do_eval
|
if training_args.do_eval
|
||||||
else None
|
else None
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from enum import Enum
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from filelock import FileLock
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||||
|
|
||||||
@@ -77,7 +78,6 @@ class Split(Enum):
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
from transformers import torch_distributed_zero_first
|
|
||||||
|
|
||||||
class MultipleChoiceDataset(Dataset):
|
class MultipleChoiceDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
@@ -95,7 +95,6 @@ if is_torch_available():
|
|||||||
max_seq_length: Optional[int] = None,
|
max_seq_length: Optional[int] = None,
|
||||||
overwrite_cache=False,
|
overwrite_cache=False,
|
||||||
mode: Split = Split.train,
|
mode: Split = Split.train,
|
||||||
local_rank=-1,
|
|
||||||
):
|
):
|
||||||
processor = processors[task]()
|
processor = processors[task]()
|
||||||
|
|
||||||
@@ -103,9 +102,11 @@ if is_torch_available():
|
|||||||
data_dir,
|
data_dir,
|
||||||
"cached_{}_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length), task,),
|
"cached_{}_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length), task,),
|
||||||
)
|
)
|
||||||
with torch_distributed_zero_first(local_rank):
|
|
||||||
# Make sure only the first process in distributed training processes the dataset,
|
# Make sure only the first process in distributed training processes the dataset,
|
||||||
# and the others will use the cache.
|
# and the others will use the cache.
|
||||||
|
lock_path = cached_features_file + ".lock"
|
||||||
|
with FileLock(lock_path):
|
||||||
|
|
||||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||||
@@ -130,7 +131,6 @@ if is_torch_available():
|
|||||||
pad_token=tokenizer.pad_token_id,
|
pad_token=tokenizer.pad_token_id,
|
||||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||||
)
|
)
|
||||||
if local_rank in [-1, 0]:
|
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(self.features, cached_features_file)
|
torch.save(self.features, cached_features_file)
|
||||||
|
|
||||||
|
|||||||
@@ -171,7 +171,6 @@ def main():
|
|||||||
max_seq_length=data_args.max_seq_length,
|
max_seq_length=data_args.max_seq_length,
|
||||||
overwrite_cache=data_args.overwrite_cache,
|
overwrite_cache=data_args.overwrite_cache,
|
||||||
mode=Split.train,
|
mode=Split.train,
|
||||||
local_rank=training_args.local_rank,
|
|
||||||
)
|
)
|
||||||
if training_args.do_train
|
if training_args.do_train
|
||||||
else None
|
else None
|
||||||
@@ -185,7 +184,6 @@ def main():
|
|||||||
max_seq_length=data_args.max_seq_length,
|
max_seq_length=data_args.max_seq_length,
|
||||||
overwrite_cache=data_args.overwrite_cache,
|
overwrite_cache=data_args.overwrite_cache,
|
||||||
mode=Split.dev,
|
mode=Split.dev,
|
||||||
local_rank=training_args.local_rank,
|
|
||||||
)
|
)
|
||||||
if training_args.do_eval
|
if training_args.do_eval
|
||||||
else None
|
else None
|
||||||
@@ -261,7 +259,6 @@ def main():
|
|||||||
max_seq_length=data_args.max_seq_length,
|
max_seq_length=data_args.max_seq_length,
|
||||||
overwrite_cache=data_args.overwrite_cache,
|
overwrite_cache=data_args.overwrite_cache,
|
||||||
mode=Split.test,
|
mode=Split.test,
|
||||||
local_rank=training_args.local_rank,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
predictions, label_ids, metrics = trainer.predict(test_dataset)
|
predictions, label_ids, metrics = trainer.predict(test_dataset)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from filelock import FileLock
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
@@ -68,7 +70,6 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
from transformers import torch_distributed_zero_first
|
|
||||||
|
|
||||||
class NerDataset(Dataset):
|
class NerDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
@@ -90,16 +91,16 @@ if is_torch_available():
|
|||||||
max_seq_length: Optional[int] = None,
|
max_seq_length: Optional[int] = None,
|
||||||
overwrite_cache=False,
|
overwrite_cache=False,
|
||||||
mode: Split = Split.train,
|
mode: Split = Split.train,
|
||||||
local_rank=-1,
|
|
||||||
):
|
):
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(
|
cached_features_file = os.path.join(
|
||||||
data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
|
data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch_distributed_zero_first(local_rank):
|
|
||||||
# Make sure only the first process in distributed training processes the dataset,
|
# Make sure only the first process in distributed training processes the dataset,
|
||||||
# and the others will use the cache.
|
# and the others will use the cache.
|
||||||
|
lock_path = cached_features_file + ".lock"
|
||||||
|
with FileLock(lock_path):
|
||||||
|
|
||||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||||
@@ -125,7 +126,6 @@ if is_torch_available():
|
|||||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||||
pad_token_label_id=self.pad_token_label_id,
|
pad_token_label_id=self.pad_token_label_id,
|
||||||
)
|
)
|
||||||
if local_rank in [-1, 0]:
|
|
||||||
logger.info(f"Saving features into cached file {cached_features_file}")
|
logger.info(f"Saving features into cached file {cached_features_file}")
|
||||||
torch.save(self.features, cached_features_file)
|
torch.save(self.features, cached_features_file)
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ import pickle
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from filelock import FileLock
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...trainer import torch_distributed_zero_first
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -20,7 +20,7 @@ class TextDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False, local_rank=-1,
|
self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False,
|
||||||
):
|
):
|
||||||
assert os.path.isfile(file_path)
|
assert os.path.isfile(file_path)
|
||||||
|
|
||||||
@@ -31,9 +31,10 @@ class TextDataset(Dataset):
|
|||||||
directory, "cached_lm_{}_{}_{}".format(tokenizer.__class__.__name__, str(block_size), filename,),
|
directory, "cached_lm_{}_{}_{}".format(tokenizer.__class__.__name__, str(block_size), filename,),
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch_distributed_zero_first(local_rank):
|
|
||||||
# Make sure only the first process in distributed training processes the dataset,
|
# Make sure only the first process in distributed training processes the dataset,
|
||||||
# and the others will use the cache.
|
# and the others will use the cache.
|
||||||
|
lock_path = cached_features_file + ".lock"
|
||||||
|
with FileLock(lock_path):
|
||||||
|
|
||||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -80,7 +81,7 @@ class LineByLineTextDataset(Dataset):
|
|||||||
soon.
|
soon.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1):
|
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
|
||||||
assert os.path.isfile(file_path)
|
assert os.path.isfile(file_path)
|
||||||
# Here, we do not cache the features, operating under the assumption
|
# Here, we do not cache the features, operating under the assumption
|
||||||
# that we will soon use fast multithreaded tokenizers from the
|
# that we will soon use fast multithreaded tokenizers from the
|
||||||
|
|||||||
Reference in New Issue
Block a user