Added a --reduce_memory option to shelve docs to disc instead of keeping them in memory.
This commit is contained in:
@@ -1,46 +1,81 @@
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
import shelve
|
||||||
|
|
||||||
from random import random, randint, shuffle, choice, sample
|
from random import random, randint, shuffle, choice, sample
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
class DocumentDatabase:
|
class DocumentDatabase:
|
||||||
def __init__(self, document_list):
|
def __init__(self, reduce_memory=False, working_dir=None):
|
||||||
self.document_list = document_list
|
if reduce_memory:
|
||||||
self.doc_starts = {}
|
if working_dir is None:
|
||||||
self.weighted_doc_samples = []
|
self.temp_dir = TemporaryDirectory()
|
||||||
i = 0
|
self.working_dir = Path(self.temp_dir.name)
|
||||||
for doc_idx, doc in enumerate(document_list):
|
else:
|
||||||
self.doc_starts[doc_idx] = i
|
self.temp_dir = None
|
||||||
self.weighted_doc_samples.extend([doc_idx] * len(doc))
|
self.working_dir = Path(working_dir)
|
||||||
i += len(doc)
|
self.working_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.document_shelf_filepath = self.working_dir / 'shelf.db'
|
||||||
|
self.document_shelf = shelve.open(str(self.document_shelf_filepath),
|
||||||
|
flag='n', protocol=-1)
|
||||||
|
self.documents = None
|
||||||
|
else:
|
||||||
|
self.documents = []
|
||||||
|
self.document_shelf = None
|
||||||
|
self.document_shelf_filepath = None
|
||||||
|
self.doc_lengths = []
|
||||||
|
self.doc_cumsum = None
|
||||||
|
self.cumsum_max = None
|
||||||
|
self.reduce_memory = reduce_memory
|
||||||
|
|
||||||
|
def add_document(self, document):
|
||||||
|
if self.reduce_memory:
|
||||||
|
current_idx = len(self.doc_lengths)
|
||||||
|
self.document_shelf[str(current_idx)] = document
|
||||||
|
else:
|
||||||
|
self.documents.append(document)
|
||||||
|
self.doc_lengths.append(len(document))
|
||||||
|
|
||||||
|
def _precalculate_doc_weights(self):
|
||||||
|
self.doc_cumsum = np.cumsum(self.doc_lengths)
|
||||||
|
self.cumsum_max = self.doc_cumsum[-1]
|
||||||
|
|
||||||
def sample_doc(self, current_idx, sentence_weighted=True):
|
def sample_doc(self, current_idx, sentence_weighted=True):
|
||||||
# Uses the current iteration counter to ensure we don't sample the same doc twice
|
# Uses the current iteration counter to ensure we don't sample the same doc twice
|
||||||
if sentence_weighted:
|
if sentence_weighted:
|
||||||
num_sentences = len(self.document_list[current_idx])
|
# With sentence weighting, we sample docs proportionally to their sentence length
|
||||||
# This very painful line randomly selects a document, weighted by the number of sentences they contain,
|
if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
|
||||||
# while guaranteeing that it won't return the original document
|
self._precalculate_doc_weights()
|
||||||
sampled_val = (
|
rand_start = self.doc_cumsum[current_idx]
|
||||||
(self.doc_starts[current_idx] + num_sentences
|
rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
|
||||||
+ randint(0, len(self.weighted_doc_samples) - num_sentences - 1))
|
sentence_index = randint(rand_start, rand_end) % self.cumsum_max
|
||||||
% len(self.weighted_doc_samples))
|
sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
|
||||||
sampled_doc_index = self.weighted_doc_samples[sampled_val]
|
|
||||||
else:
|
else:
|
||||||
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
|
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
|
||||||
sampled_doc_index = current_idx + randint(1, len(self.document_list)-1)
|
sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
|
||||||
assert sampled_doc_index != current_idx
|
assert sampled_doc_index != current_idx
|
||||||
return self.document_list[sampled_doc_index]
|
if self.reduce_memory:
|
||||||
|
return self.document_shelf[str(sampled_doc_index)]
|
||||||
|
else:
|
||||||
|
return self.documents[sampled_doc_index]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.document_list)
|
return len(self.doc_lengths)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.document_list[item]
|
if self.reduce_memory:
|
||||||
|
return self.document_shelf[str(item)]
|
||||||
|
else:
|
||||||
|
return self.documents[item]
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if self.document_shelf is not None:
|
||||||
|
self.document_shelf.close()
|
||||||
|
|
||||||
|
|
||||||
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
|
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
|
||||||
@@ -200,6 +235,11 @@ def main():
|
|||||||
"bert-base-multilingual", "bert-base-chinese"])
|
"bert-base-multilingual", "bert-base-chinese"])
|
||||||
parser.add_argument("--do_lower_case", action="store_true")
|
parser.add_argument("--do_lower_case", action="store_true")
|
||||||
|
|
||||||
|
parser.add_argument("--reduce_memory", action="store_true",
|
||||||
|
help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
|
||||||
|
parser.add_argument("--working_dir", type=Path, default=None,
|
||||||
|
help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
|
||||||
|
|
||||||
parser.add_argument("--epochs_to_generate", type=int, default=3,
|
parser.add_argument("--epochs_to_generate", type=int, default=3,
|
||||||
help="Number of epochs of data to pregenerate")
|
help="Number of epochs of data to pregenerate")
|
||||||
parser.add_argument("--max_seq_len", type=int, default=128)
|
parser.add_argument("--max_seq_len", type=int, default=128)
|
||||||
@@ -212,31 +252,21 @@ def main():
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# TODO Add a low-memory / multiprocessing path for very large datasets
|
|
||||||
# In this path documents would be stored in a shelf after being tokenized, and multiple processes would convert
|
|
||||||
# those docs into training examples that would be written out on the fly. This would avoid the need to keep
|
|
||||||
# the whole training set in memory and would speed up dataset creation at the cost of code complexity.
|
|
||||||
# In addition, the finetuning script would need to be modified
|
|
||||||
# to store the training epochs as memmapped arrays.
|
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
vocab_list = list(tokenizer.vocab.keys())
|
vocab_list = list(tokenizer.vocab.keys())
|
||||||
|
docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
|
||||||
with args.train_corpus.open() as f:
|
with args.train_corpus.open() as f:
|
||||||
docs = []
|
|
||||||
doc = []
|
doc = []
|
||||||
for line in tqdm(f, desc="Loading Dataset"):
|
for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line == "":
|
if line == "":
|
||||||
docs.append(doc)
|
docs.add_document(doc)
|
||||||
doc = []
|
doc = []
|
||||||
else:
|
else:
|
||||||
tokens = tokenizer.tokenize(line)
|
tokens = tokenizer.tokenize(line)
|
||||||
doc.append(tokens)
|
doc.append(tokens)
|
||||||
|
|
||||||
args.output_dir.mkdir(exist_ok=True)
|
args.output_dir.mkdir(exist_ok=True)
|
||||||
docs = DocumentDatabase(docs)
|
|
||||||
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
|
|
||||||
# Google BERT doesn't do this, and as a result oversamples shorter docs
|
|
||||||
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
||||||
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
|
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
|
||||||
num_instances = 0
|
num_instances = 0
|
||||||
@@ -257,6 +287,7 @@ def main():
|
|||||||
"max_seq_len": args.max_seq_len
|
"max_seq_len": args.max_seq_len
|
||||||
}
|
}
|
||||||
metrics_file.write(json.dumps(metrics))
|
metrics_file.write(json.dumps(metrics))
|
||||||
|
docs.cleanup()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user