[skip ci] remove local rank

This commit is contained in:
Julien Chaumond
2020-05-15 17:08:38 -04:00
parent 62427d0815
commit 15550ce0d1
2 changed files with 4 additions and 12 deletions

View File

@@ -115,7 +115,7 @@ class DataTrainingArguments:
)
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
file_path = args.eval_data_file if evaluate else args.train_data_file
if args.line_by_line:
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
@@ -216,16 +216,8 @@ def main():
data_args.block_size = min(data_args.block_size, tokenizer.max_len)
# Get datasets
train_dataset = (
get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank)
if training_args.do_train
else None
)
eval_dataset = (
get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
if training_args.do_eval
else None
)
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
)