From d7929899daaca2d0b910dcef1181ea496f0b4909 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Thu, 14 Nov 2019 10:49:00 -0500
Subject: [PATCH] Specify checkpoint in saved file for run_lm_finetuning.py

---
 examples/run_lm_finetuning.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py
index 2044cfe9e8..a143d55894 100644
--- a/examples/run_lm_finetuning.py
+++ b/examples/run_lm_finetuning.py
@@ -63,10 +63,10 @@ MODEL_CLASSES = {


 class TextDataset(Dataset):
-    def __init__(self, tokenizer, file_path='train', block_size=512):
+    def __init__(self, tokenizer, args, file_path='train', block_size=512):
         assert os.path.isfile(file_path)
         directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename)
+        cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)

         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
@@ -99,7 +99,7 @@ class TextDataset(Dataset):


 def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
+    dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
     return dataset

