From 2bba7f810e37e3cd64b5a74de29b58960af6d7b5 Mon Sep 17 00:00:00 2001 From: Matthew Carrigan Date: Thu, 21 Mar 2019 16:50:16 +0000 Subject: [PATCH] Added a --reduce_memory option to shelve docs to disc instead of keeping them in memory. --- .../pregenerate_training_data.py | 101 ++++++++++++------ 1 file changed, 66 insertions(+), 35 deletions(-) diff --git a/examples/lm_finetuning/pregenerate_training_data.py b/examples/lm_finetuning/pregenerate_training_data.py index ac7a1ae076..98e04eaf26 100644 --- a/examples/lm_finetuning/pregenerate_training_data.py +++ b/examples/lm_finetuning/pregenerate_training_data.py @@ -1,46 +1,81 @@ from argparse import ArgumentParser from pathlib import Path from tqdm import tqdm, trange +from tempfile import TemporaryDirectory +import shelve from random import random, randint, shuffle, choice, sample from pytorch_pretrained_bert.tokenization import BertTokenizer - +import numpy as np import json class DocumentDatabase: - def __init__(self, document_list): - self.document_list = document_list - self.doc_starts = {} - self.weighted_doc_samples = [] - i = 0 - for doc_idx, doc in enumerate(document_list): - self.doc_starts[doc_idx] = i - self.weighted_doc_samples.extend([doc_idx] * len(doc)) - i += len(doc) + def __init__(self, reduce_memory=False, working_dir=None): + if reduce_memory: + if working_dir is None: + self.temp_dir = TemporaryDirectory() + self.working_dir = Path(self.temp_dir.name) + else: + self.temp_dir = None + self.working_dir = Path(working_dir) + self.working_dir.mkdir(parents=True, exist_ok=True) + self.document_shelf_filepath = self.working_dir / 'shelf.db' + self.document_shelf = shelve.open(str(self.document_shelf_filepath), + flag='n', protocol=-1) + self.documents = None + else: + self.documents = [] + self.document_shelf = None + self.document_shelf_filepath = None + self.doc_lengths = [] + self.doc_cumsum = None + self.cumsum_max = None + self.reduce_memory = reduce_memory + + def add_document(self, document): + if self.reduce_memory: + current_idx = len(self.doc_lengths) + self.document_shelf[str(current_idx)] = document + else: + self.documents.append(document) + self.doc_lengths.append(len(document)) + + def _precalculate_doc_weights(self): + self.doc_cumsum = np.cumsum(self.doc_lengths) + self.cumsum_max = self.doc_cumsum[-1] def sample_doc(self, current_idx, sentence_weighted=True): # Uses the current iteration counter to ensure we don't sample the same doc twice if sentence_weighted: - num_sentences = len(self.document_list[current_idx]) - # This very painful line randomly selects a document, weighted by the number of sentences they contain, - # while guaranteeing that it won't return the original document - sampled_val = ( - (self.doc_starts[current_idx] + num_sentences - + randint(0, len(self.weighted_doc_samples) - num_sentences - 1)) - % len(self.weighted_doc_samples)) - sampled_doc_index = self.weighted_doc_samples[sampled_val] + # With sentence weighting, we sample docs proportionally to their sentence length + if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths): + self._precalculate_doc_weights() + rand_start = self.doc_cumsum[current_idx] + rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] + sentence_index = randint(rand_start, rand_end) % self.cumsum_max + sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') else: # If we don't use sentence weighting, then every doc has an equal chance to be chosen - sampled_doc_index = current_idx + randint(1, len(self.document_list)-1) + sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1) assert sampled_doc_index != current_idx - return self.document_list[sampled_doc_index] + if self.reduce_memory: + return self.document_shelf[str(sampled_doc_index)] + else: + return self.documents[sampled_doc_index] def __len__(self): - return len(self.document_list) + return len(self.doc_lengths) def __getitem__(self, item): - return self.document_list[item] + if self.reduce_memory: + return self.document_shelf[str(item)] + else: + return self.documents[item] + + def cleanup(self): + if self.document_shelf is not None: + self.document_shelf.close() def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): @@ -200,6 +235,11 @@ def main(): "bert-base-multilingual", "bert-base-chinese"]) parser.add_argument("--do_lower_case", action="store_true") + parser.add_argument("--reduce_memory", action="store_true", + help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") + parser.add_argument("--working_dir", type=Path, default=None, + help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()") + parser.add_argument("--epochs_to_generate", type=int, default=3, help="Number of epochs of data to pregenerate") parser.add_argument("--max_seq_len", type=int, default=128) @@ -212,31 +252,21 @@ def main(): args = parser.parse_args() - # TODO Add a low-memory / multiprocessing path for very large datasets - # In this path documents would be stored in a shelf after being tokenized, and multiple processes would convert - # those docs into training examples that would be written out on the fly. This would avoid the need to keep - # the whole training set in memory and would speed up dataset creation at the cost of code complexity. - # In addition, the finetuning script would need to be modified - # to store the training epochs as memmapped arrays. - tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) vocab_list = list(tokenizer.vocab.keys()) + docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir) with args.train_corpus.open() as f: - docs = [] doc = [] - for line in tqdm(f, desc="Loading Dataset"): + for line in tqdm(f, desc="Loading Dataset", unit=" lines"): line = line.strip() if line == "": - docs.append(doc) + docs.add_document(doc) doc = [] else: tokens = tokenizer.tokenize(line) doc.append(tokens) args.output_dir.mkdir(exist_ok=True) - docs = DocumentDatabase(docs) - # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain - # Google BERT doesn't do this, and as a result oversamples shorter docs for epoch in trange(args.epochs_to_generate, desc="Epoch"): epoch_filename = args.output_dir / f"epoch_{epoch}.json" num_instances = 0 @@ -257,6 +287,7 @@ def main(): "max_seq_len": args.max_seq_len } metrics_file.write(json.dumps(metrics)) + docs.cleanup() if __name__ == '__main__':