From dbbd6c7500dded778706326c7a1e402cffe97eb8 Mon Sep 17 00:00:00 2001 From: Matthew Carrigan Date: Fri, 12 Apr 2019 15:07:58 +0100 Subject: [PATCH] Replaced some randints with cleaner randranges, and added a helpful error for users whose corpus is just one giant document. --- .../pregenerate_training_data.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/lm_finetuning/pregenerate_training_data.py b/examples/lm_finetuning/pregenerate_training_data.py index 8cc28d2e78..e6c3598a9f 100644 --- a/examples/lm_finetuning/pregenerate_training_data.py +++ b/examples/lm_finetuning/pregenerate_training_data.py @@ -4,7 +4,7 @@ from tqdm import tqdm, trange from tempfile import TemporaryDirectory import shelve -from random import random, randint, shuffle, choice, sample +from random import random, randrange, randint, shuffle, choice, sample from pytorch_pretrained_bert.tokenization import BertTokenizer import numpy as np import json @@ -30,6 +30,8 @@ class DocumentDatabase: self.reduce_memory = reduce_memory def add_document(self, document): + if not document: + return if self.reduce_memory: current_idx = len(self.doc_lengths) self.document_shelf[str(current_idx)] = document @@ -49,11 +51,11 @@ class DocumentDatabase: self._precalculate_doc_weights() rand_start = self.doc_cumsum[current_idx] rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] - sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max + sentence_index = randrange(rand_start, rand_end) % self.cumsum_max sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') else: # If we don't use sentence weighting, then every doc has an equal chance to be chosen - sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1) + sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths) assert sampled_doc_index != current_idx if self.reduce_memory: return self.document_shelf[str(sampled_doc_index)] @@ -170,7 +172,7 @@ def create_instances_from_document( # (first) sentence. a_end = 1 if len(current_chunk) >= 2: - a_end = randint(1, len(current_chunk) - 1) + a_end = randrange(1, len(current_chunk)) tokens_a = [] for j in range(a_end): @@ -186,7 +188,7 @@ def create_instances_from_document( # Sample a random document, with longer docs being sampled more frequently random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True) - random_start = randint(0, len(random_document) - 1) + random_start = randrange(0, len(random_document)) for j in range(random_start, len(random_document)): tokens_b.extend(random_document[j]) if len(tokens_b) >= target_b_length: @@ -264,6 +266,14 @@ def main(): else: tokens = tokenizer.tokenize(line) doc.append(tokens) + if doc: + docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added + if len(docs) <= 1: + exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to " + "ensure that random NextSentences are not sampled from the same document. Please add blank lines to " + "indicate breaks between documents in your input file. If your dataset does not contain multiple " + "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, " + "sections or paragraphs.") args.output_dir.mkdir(exist_ok=True) for epoch in trange(args.epochs_to_generate, desc="Epoch"):