[skip ci] remove local rank
This commit is contained in:
@@ -115,7 +115,7 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
|
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
|
||||||
file_path = args.eval_data_file if evaluate else args.train_data_file
|
file_path = args.eval_data_file if evaluate else args.train_data_file
|
||||||
if args.line_by_line:
|
if args.line_by_line:
|
||||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||||
@@ -216,16 +216,8 @@ def main():
|
|||||||
data_args.block_size = min(data_args.block_size, tokenizer.max_len)
|
data_args.block_size = min(data_args.block_size, tokenizer.max_len)
|
||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
train_dataset = (
|
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
||||||
get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank)
|
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
||||||
if training_args.do_train
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
eval_dataset = (
|
|
||||||
get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
|
|
||||||
if training_args.do_eval
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
data_collator = DataCollatorForLanguageModeling(
|
data_collator = DataCollatorForLanguageModeling(
|
||||||
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
|
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -68,6 +68,6 @@ class RobertaConfig(BertConfig):
|
|||||||
model_type = "roberta"
|
model_type = "roberta"
|
||||||
|
|
||||||
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
|
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
|
||||||
"""Constructs FlaubertConfig.
|
"""Constructs RobertaConfig.
|
||||||
"""
|
"""
|
||||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user