Merge branch 'master' into master

This commit is contained in:
Thomas Wolf
2019-04-15 10:58:34 +02:00
committed by GitHub
8 changed files with 46 additions and 18 deletions

View File

@@ -37,6 +37,7 @@ python3 simple_lm_finetuning.py
--bert_model bert-base-uncased
--do_lower_case
--output_dir finetuned_lm/
--do_train
```
### Pregenerating training data
@@ -60,4 +61,4 @@ python3 finetune_on_pregenerated.py
--do_lower_case
--output_dir finetuned_lm/
--epochs 3
```
```

View File

@@ -123,9 +123,8 @@ def main():
parser = ArgumentParser()
parser.add_argument('--pregenerated_data', type=Path, required=True)
parser.add_argument('--output_dir', type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
"bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--reduce_memory", action="store_true",
help="Store training data as on-disc memmaps to massively reduce memory usage")

View File

@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
from tempfile import TemporaryDirectory
import shelve
from random import random, randint, shuffle, choice, sample
from random import random, randrange, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np
import json
@@ -30,6 +30,8 @@ class DocumentDatabase:
self.reduce_memory = reduce_memory
def add_document(self, document):
if not document:
return
if self.reduce_memory:
current_idx = len(self.doc_lengths)
self.document_shelf[str(current_idx)] = document
@@ -49,11 +51,11 @@ class DocumentDatabase:
self._precalculate_doc_weights()
rand_start = self.doc_cumsum[current_idx]
rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
else:
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
assert sampled_doc_index != current_idx
if self.reduce_memory:
return self.document_shelf[str(sampled_doc_index)]
@@ -170,7 +172,7 @@ def create_instances_from_document(
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = randint(1, len(current_chunk) - 1)
a_end = randrange(1, len(current_chunk))
tokens_a = []
for j in range(a_end):
@@ -186,7 +188,7 @@ def create_instances_from_document(
# Sample a random document, with longer docs being sampled more frequently
random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
random_start = randint(0, len(random_document) - 1)
random_start = randrange(0, len(random_document))
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
@@ -264,6 +266,14 @@ def main():
else:
tokens = tokenizer.tokenize(line)
doc.append(tokens)
if doc:
docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added
if len(docs) <= 1:
exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
"ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
"indicate breaks between documents in your input file. If your dataset does not contain multiple "
"documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
"sections or paragraphs.")
args.output_dir.mkdir(exist_ok=True)
for epoch in trange(args.epochs_to_generate, desc="Epoch"):

View File

@@ -95,7 +95,7 @@ class DataProcessor(object):
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
with open(input_file, "r", encoding="utf-8") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:

View File

@@ -83,8 +83,9 @@ def run_model():
elif args.length > model.config.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
if not args.unconditional:
while True:
while True:
context_tokens = []
if not args.unconditional:
raw_text = input("Model prompt >>> ")
while not raw_text:
print('Prompt should not be empty!')
@@ -123,6 +124,8 @@ def run_model():
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if args.unconditional:
break
if __name__ == '__main__':
run_model()