Replaced some randints with cleaner randranges, and added a helpful
error for users whose corpus is just one giant document.
This commit is contained in:
@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
|
|||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
import shelve
|
import shelve
|
||||||
|
|
||||||
from random import random, randint, shuffle, choice, sample
|
from random import random, randrange, randint, shuffle, choice, sample
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
@@ -30,6 +30,8 @@ class DocumentDatabase:
|
|||||||
self.reduce_memory = reduce_memory
|
self.reduce_memory = reduce_memory
|
||||||
|
|
||||||
def add_document(self, document):
|
def add_document(self, document):
|
||||||
|
if not document:
|
||||||
|
return
|
||||||
if self.reduce_memory:
|
if self.reduce_memory:
|
||||||
current_idx = len(self.doc_lengths)
|
current_idx = len(self.doc_lengths)
|
||||||
self.document_shelf[str(current_idx)] = document
|
self.document_shelf[str(current_idx)] = document
|
||||||
@@ -49,11 +51,11 @@ class DocumentDatabase:
|
|||||||
self._precalculate_doc_weights()
|
self._precalculate_doc_weights()
|
||||||
rand_start = self.doc_cumsum[current_idx]
|
rand_start = self.doc_cumsum[current_idx]
|
||||||
rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
|
rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
|
||||||
sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
|
sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
|
||||||
sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
|
sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
|
||||||
else:
|
else:
|
||||||
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
|
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
|
||||||
sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
|
sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
|
||||||
assert sampled_doc_index != current_idx
|
assert sampled_doc_index != current_idx
|
||||||
if self.reduce_memory:
|
if self.reduce_memory:
|
||||||
return self.document_shelf[str(sampled_doc_index)]
|
return self.document_shelf[str(sampled_doc_index)]
|
||||||
@@ -170,7 +172,7 @@ def create_instances_from_document(
|
|||||||
# (first) sentence.
|
# (first) sentence.
|
||||||
a_end = 1
|
a_end = 1
|
||||||
if len(current_chunk) >= 2:
|
if len(current_chunk) >= 2:
|
||||||
a_end = randint(1, len(current_chunk) - 1)
|
a_end = randrange(1, len(current_chunk))
|
||||||
|
|
||||||
tokens_a = []
|
tokens_a = []
|
||||||
for j in range(a_end):
|
for j in range(a_end):
|
||||||
@@ -186,7 +188,7 @@ def create_instances_from_document(
|
|||||||
# Sample a random document, with longer docs being sampled more frequently
|
# Sample a random document, with longer docs being sampled more frequently
|
||||||
random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
|
random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
|
||||||
|
|
||||||
random_start = randint(0, len(random_document) - 1)
|
random_start = randrange(0, len(random_document))
|
||||||
for j in range(random_start, len(random_document)):
|
for j in range(random_start, len(random_document)):
|
||||||
tokens_b.extend(random_document[j])
|
tokens_b.extend(random_document[j])
|
||||||
if len(tokens_b) >= target_b_length:
|
if len(tokens_b) >= target_b_length:
|
||||||
@@ -264,6 +266,14 @@ def main():
|
|||||||
else:
|
else:
|
||||||
tokens = tokenizer.tokenize(line)
|
tokens = tokenizer.tokenize(line)
|
||||||
doc.append(tokens)
|
doc.append(tokens)
|
||||||
|
if doc:
|
||||||
|
docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added
|
||||||
|
if len(docs) <= 1:
|
||||||
|
exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
|
||||||
|
"ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
|
||||||
|
"indicate breaks between documents in your input file. If your dataset does not contain multiple "
|
||||||
|
"documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
|
||||||
|
"sections or paragraphs.")
|
||||||
|
|
||||||
args.output_dir.mkdir(exist_ok=True)
|
args.output_dir.mkdir(exist_ok=True)
|
||||||
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
||||||
|
|||||||
Reference in New Issue
Block a user