From d9ece8233d584cdc2eeae5165dd3329328fae328 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Mon, 18 May 2020 10:37:35 -0500
Subject: [PATCH] fix(run_language_modeling): use arg overwrite_cache (#4407)

---
 examples/language-modeling/run_language_modeling.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/language-modeling/run_language_modeling.py b/examples/language-modeling/run_language_modeling.py
index 483d98fad9..d2afa56acb 100644
--- a/examples/language-modeling/run_language_modeling.py
+++ b/examples/language-modeling/run_language_modeling.py
@@ -120,7 +120,9 @@ def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, eva
     if args.line_by_line:
         return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
     else:
-        return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
+        return TextDataset(
+            tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
+        )
 
 
 def main():
@@ -216,6 +218,7 @@ def main():
         data_args.block_size = min(data_args.block_size, tokenizer.max_len)
 
     # Get datasets
+
     train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
     eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
     data_collator = DataCollatorForLanguageModeling(