Merge pull request #733 from ceremonious/parallel-generation

Added option to use multiple workers to create training data
This commit is contained in:
Thomas Wolf
2019-07-05 12:04:30 +02:00
committed by GitHub

View File

@@ -3,6 +3,7 @@ from pathlib import Path
from tqdm import tqdm, trange from tqdm import tqdm, trange
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import shelve import shelve
from multiprocessing import Pool
from random import random, randrange, randint, shuffle, choice from random import random, randrange, randint, shuffle, choice
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -264,6 +265,28 @@ def create_instances_from_document(
return instances return instances
def create_training_file(docs, vocab_list, args, epoch_num):
    """Write one epoch of training instances together with its metrics file.

    Two files are created inside ``args.output_dir``:

    * ``epoch_{epoch_num}.json`` -- one JSON-encoded instance per line.
    * ``epoch_{epoch_num}_metrics.json`` -- a JSON object holding the number
      of instances written and the maximum sequence length.

    Args:
        docs: Document collection. It must support ``len()`` and is passed,
            along with a document index, to ``create_instances_from_document``.
        vocab_list: Vocabulary forwarded to ``create_instances_from_document``.
        args: Parsed options. Reads ``output_dir``, ``max_seq_len``,
            ``short_seq_prob``, ``masked_lm_prob``,
            ``max_predictions_per_seq`` and ``do_whole_word_mask``.
        epoch_num: Epoch index used in both output file names.
    """
    instances_path = args.output_dir / "epoch_{}.json".format(epoch_num)
    metrics_path = args.output_dir / "epoch_{}_metrics.json".format(epoch_num)

    written = 0
    with instances_path.open('w') as instances_out:
        for doc_idx in trange(len(docs), desc="Document"):
            # Serialize every instance of the document before writing any of
            # them, so a serialization error never leaves a partial document.
            serialized = [
                json.dumps(instance)
                for instance in create_instances_from_document(
                    docs, doc_idx,
                    max_seq_length=args.max_seq_len,
                    short_seq_prob=args.short_seq_prob,
                    masked_lm_prob=args.masked_lm_prob,
                    max_predictions_per_seq=args.max_predictions_per_seq,
                    whole_word_mask=args.do_whole_word_mask,
                    vocab_list=vocab_list)
            ]
            for line in serialized:
                instances_out.write(line + '\n')
            written += len(serialized)

    summary = {
        "num_training_examples": written,
        "max_seq_len": args.max_seq_len
    }
    with metrics_path.open('w') as metrics_out:
        metrics_out.write(json.dumps(summary))
def main(): def main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('--train_corpus', type=Path, required=True) parser.add_argument('--train_corpus', type=Path, required=True)
@@ -277,6 +300,8 @@ def main():
parser.add_argument("--reduce_memory", action="store_true", parser.add_argument("--reduce_memory", action="store_true",
help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
parser.add_argument("--num_workers", type=int, default=1,
help="The number of workers to use to write the files")
parser.add_argument("--epochs_to_generate", type=int, default=3, parser.add_argument("--epochs_to_generate", type=int, default=3,
help="Number of epochs of data to pregenerate") help="Number of epochs of data to pregenerate")
parser.add_argument("--max_seq_len", type=int, default=128) parser.add_argument("--max_seq_len", type=int, default=128)
@@ -289,6 +314,9 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if args.num_workers > 1 and args.reduce_memory:
raise ValueError("Cannot use multiple workers while reducing memory")
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
vocab_list = list(tokenizer.vocab.keys()) vocab_list = list(tokenizer.vocab.keys())
with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
@@ -312,26 +340,14 @@ def main():
"sections or paragraphs.") "sections or paragraphs.")
args.output_dir.mkdir(exist_ok=True) args.output_dir.mkdir(exist_ok=True)
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
epoch_filename = args.output_dir / f"epoch_{epoch}.json" if args.num_workers > 1:
num_instances = 0 writer_workers = Pool(min(args.num_workers, args.epochs_to_generate))
with epoch_filename.open('w') as epoch_file: arguments = [(docs, vocab_list, args, idx) for idx in range(args.epochs_to_generate)]
for doc_idx in trange(len(docs), desc="Document"): writer_workers.starmap(create_training_file, arguments)
doc_instances = create_instances_from_document( else:
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, for epoch in trange(args.epochs_to_generate, desc="Epoch"):
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, create_training_file(docs, vocab_list, args, epoch)
whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
doc_instances = [json.dumps(instance) for instance in doc_instances]
for instance in doc_instances:
epoch_file.write(instance + '\n')
num_instances += 1
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
with metrics_file.open('w') as metrics_file:
metrics = {
"num_training_examples": num_instances,
"max_seq_len": args.max_seq_len
}
metrics_file.write(json.dumps(metrics))
if __name__ == '__main__': if __name__ == '__main__':