From 08ff056c4343c6e118203aac81c3124a6e239973 Mon Sep 17 00:00:00 2001 From: Mayhul Arora Date: Wed, 26 Jun 2019 16:16:12 -0700 Subject: [PATCH] Added option to use multiple workers to create training data for lm fine tuning --- .../pregenerate_training_data.py | 56 ++++++++++++------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/examples/lm_finetuning/pregenerate_training_data.py b/examples/lm_finetuning/pregenerate_training_data.py index 8bed1e54d4..d8a241c6a5 100644 --- a/examples/lm_finetuning/pregenerate_training_data.py +++ b/examples/lm_finetuning/pregenerate_training_data.py @@ -3,6 +3,7 @@ from pathlib import Path from tqdm import tqdm, trange from tempfile import TemporaryDirectory import shelve +from multiprocessing import Pool from random import random, randrange, randint, shuffle, choice from pytorch_pretrained_bert.tokenization import BertTokenizer @@ -264,6 +265,28 @@ def create_instances_from_document( return instances +def create_training_file(docs, vocab_list, args, epoch_num): + epoch_filename = args.output_dir / "epoch_{}.json".format(epoch_num) + num_instances = 0 + with epoch_filename.open('w') as epoch_file: + for doc_idx in trange(len(docs), desc="Document"): + doc_instances = create_instances_from_document( + docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, + masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, + whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list) + doc_instances = [json.dumps(instance) for instance in doc_instances] + for instance in doc_instances: + epoch_file.write(instance + '\n') + num_instances += 1 + metrics_file = args.output_dir / "epoch_{}_metrics.json".format(epoch_num) + with metrics_file.open('w') as metrics_file: + metrics = { + "num_training_examples": num_instances, + "max_seq_len": args.max_seq_len + } + metrics_file.write(json.dumps(metrics)) + + def main(): parser = ArgumentParser() parser.add_argument('--train_corpus', type=Path, required=True) @@ -277,6 +300,8 @@ def main(): parser.add_argument("--reduce_memory", action="store_true", help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") + parser.add_argument("--num_workers", type=int, default=1, + help="The number of workers to use to write the files") parser.add_argument("--epochs_to_generate", type=int, default=3, help="Number of epochs of data to pregenerate") parser.add_argument("--max_seq_len", type=int, default=128) @@ -289,6 +314,9 @@ def main(): args = parser.parse_args() + if args.num_workers > 1 and args.reduce_memory: + raise ValueError("Cannot use multiple workers while reducing memory") + tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) vocab_list = list(tokenizer.vocab.keys()) with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: @@ -312,26 +340,14 @@ def main(): "sections or paragraphs.") args.output_dir.mkdir(exist_ok=True) - for epoch in trange(args.epochs_to_generate, desc="Epoch"): - epoch_filename = args.output_dir / f"epoch_{epoch}.json" - num_instances = 0 - with epoch_filename.open('w') as epoch_file: - for doc_idx in trange(len(docs), desc="Document"): - doc_instances = create_instances_from_document( - docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, - masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, - whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list) - doc_instances = [json.dumps(instance) for instance in doc_instances] - for instance in doc_instances: - epoch_file.write(instance + '\n') - num_instances += 1 - metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json" - with metrics_file.open('w') as metrics_file: - metrics = { - "num_training_examples": num_instances, - "max_seq_len": args.max_seq_len - } - metrics_file.write(json.dumps(metrics)) + + if args.num_workers > 1: + writer_workers = Pool(min(args.num_workers, args.epochs_to_generate)) + arguments = [(docs, vocab_list, args, idx) for idx in range(args.epochs_to_generate)] + writer_workers.starmap(create_training_file, arguments) + else: + for epoch in trange(args.epochs_to_generate, desc="Epoch"): + create_training_file(docs, vocab_list, args, epoch) if __name__ == '__main__':