Merge pull request #733 from ceremonious/parallel-generation
Added option to use multiple workers to create training data
This commit is contained in:
@@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
import shelve
|
import shelve
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
from random import random, randrange, randint, shuffle, choice
|
from random import random, randrange, randint, shuffle, choice
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
@@ -264,6 +265,28 @@ def create_instances_from_document(
|
|||||||
return instances
|
return instances
|
||||||
|
|
||||||
|
|
||||||
|
def create_training_file(docs, vocab_list, args, epoch_num):
|
||||||
|
epoch_filename = args.output_dir / "epoch_{}.json".format(epoch_num)
|
||||||
|
num_instances = 0
|
||||||
|
with epoch_filename.open('w') as epoch_file:
|
||||||
|
for doc_idx in trange(len(docs), desc="Document"):
|
||||||
|
doc_instances = create_instances_from_document(
|
||||||
|
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
||||||
|
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
||||||
|
whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
|
||||||
|
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
||||||
|
for instance in doc_instances:
|
||||||
|
epoch_file.write(instance + '\n')
|
||||||
|
num_instances += 1
|
||||||
|
metrics_file = args.output_dir / "epoch_{}_metrics.json".format(epoch_num)
|
||||||
|
with metrics_file.open('w') as metrics_file:
|
||||||
|
metrics = {
|
||||||
|
"num_training_examples": num_instances,
|
||||||
|
"max_seq_len": args.max_seq_len
|
||||||
|
}
|
||||||
|
metrics_file.write(json.dumps(metrics))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument('--train_corpus', type=Path, required=True)
|
parser.add_argument('--train_corpus', type=Path, required=True)
|
||||||
@@ -277,6 +300,8 @@ def main():
|
|||||||
parser.add_argument("--reduce_memory", action="store_true",
|
parser.add_argument("--reduce_memory", action="store_true",
|
||||||
help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
|
help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
|
||||||
|
|
||||||
|
parser.add_argument("--num_workers", type=int, default=1,
|
||||||
|
help="The number of workers to use to write the files")
|
||||||
parser.add_argument("--epochs_to_generate", type=int, default=3,
|
parser.add_argument("--epochs_to_generate", type=int, default=3,
|
||||||
help="Number of epochs of data to pregenerate")
|
help="Number of epochs of data to pregenerate")
|
||||||
parser.add_argument("--max_seq_len", type=int, default=128)
|
parser.add_argument("--max_seq_len", type=int, default=128)
|
||||||
@@ -289,6 +314,9 @@ def main():
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.num_workers > 1 and args.reduce_memory:
|
||||||
|
raise ValueError("Cannot use multiple workers while reducing memory")
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
vocab_list = list(tokenizer.vocab.keys())
|
vocab_list = list(tokenizer.vocab.keys())
|
||||||
with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
|
with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
|
||||||
@@ -312,26 +340,14 @@ def main():
|
|||||||
"sections or paragraphs.")
|
"sections or paragraphs.")
|
||||||
|
|
||||||
args.output_dir.mkdir(exist_ok=True)
|
args.output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
if args.num_workers > 1:
|
||||||
|
writer_workers = Pool(min(args.num_workers, args.epochs_to_generate))
|
||||||
|
arguments = [(docs, vocab_list, args, idx) for idx in range(args.epochs_to_generate)]
|
||||||
|
writer_workers.starmap(create_training_file, arguments)
|
||||||
|
else:
|
||||||
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
||||||
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
|
create_training_file(docs, vocab_list, args, epoch)
|
||||||
num_instances = 0
|
|
||||||
with epoch_filename.open('w') as epoch_file:
|
|
||||||
for doc_idx in trange(len(docs), desc="Document"):
|
|
||||||
doc_instances = create_instances_from_document(
|
|
||||||
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
|
||||||
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
|
||||||
whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
|
|
||||||
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
|
||||||
for instance in doc_instances:
|
|
||||||
epoch_file.write(instance + '\n')
|
|
||||||
num_instances += 1
|
|
||||||
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
|
|
||||||
with metrics_file.open('w') as metrics_file:
|
|
||||||
metrics = {
|
|
||||||
"num_training_examples": num_instances,
|
|
||||||
"max_seq_len": args.max_seq_len
|
|
||||||
}
|
|
||||||
metrics_file.write(json.dumps(metrics))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user