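Made DocumentDatabase a context manager (`__enter__`/`__exit__`, replacing the
explicit `cleanup()` call in `main()`) to ensure cleanup happens in the right
order: the document shelf is closed first, then the temporary directory is
cleaned up.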
This commit is contained in:
@@ -23,6 +23,7 @@ class DocumentDatabase:
|
|||||||
self.documents = []
|
self.documents = []
|
||||||
self.document_shelf = None
|
self.document_shelf = None
|
||||||
self.document_shelf_filepath = None
|
self.document_shelf_filepath = None
|
||||||
|
self.temp_dir = None
|
||||||
self.doc_lengths = []
|
self.doc_lengths = []
|
||||||
self.doc_cumsum = None
|
self.doc_cumsum = None
|
||||||
self.cumsum_max = None
|
self.cumsum_max = None
|
||||||
@@ -68,9 +69,14 @@ class DocumentDatabase:
|
|||||||
else:
|
else:
|
||||||
return self.documents[item]
|
return self.documents[item]
|
||||||
|
|
||||||
def cleanup(self):
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, traceback):
|
||||||
if self.document_shelf is not None:
|
if self.document_shelf is not None:
|
||||||
self.document_shelf.close()
|
self.document_shelf.close()
|
||||||
|
if self.temp_dir is not None:
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
|
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
|
||||||
@@ -247,40 +253,39 @@ def main():
|
|||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
vocab_list = list(tokenizer.vocab.keys())
|
vocab_list = list(tokenizer.vocab.keys())
|
||||||
docs = DocumentDatabase(reduce_memory=args.reduce_memory)
|
with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
|
||||||
with args.train_corpus.open() as f:
|
with args.train_corpus.open() as f:
|
||||||
doc = []
|
doc = []
|
||||||
for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
|
for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line == "":
|
if line == "":
|
||||||
docs.add_document(doc)
|
docs.add_document(doc)
|
||||||
doc = []
|
doc = []
|
||||||
else:
|
else:
|
||||||
tokens = tokenizer.tokenize(line)
|
tokens = tokenizer.tokenize(line)
|
||||||
doc.append(tokens)
|
doc.append(tokens)
|
||||||
|
|
||||||
args.output_dir.mkdir(exist_ok=True)
|
args.output_dir.mkdir(exist_ok=True)
|
||||||
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
||||||
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
|
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
|
||||||
num_instances = 0
|
num_instances = 0
|
||||||
with epoch_filename.open('w') as epoch_file:
|
with epoch_filename.open('w') as epoch_file:
|
||||||
for doc_idx in trange(len(docs), desc="Document"):
|
for doc_idx in trange(len(docs), desc="Document"):
|
||||||
doc_instances = create_instances_from_document(
|
doc_instances = create_instances_from_document(
|
||||||
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
||||||
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
||||||
vocab_list=vocab_list)
|
vocab_list=vocab_list)
|
||||||
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
||||||
for instance in doc_instances:
|
for instance in doc_instances:
|
||||||
epoch_file.write(instance + '\n')
|
epoch_file.write(instance + '\n')
|
||||||
num_instances += 1
|
num_instances += 1
|
||||||
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
|
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
|
||||||
with metrics_file.open('w') as metrics_file:
|
with metrics_file.open('w') as metrics_file:
|
||||||
metrics = {
|
metrics = {
|
||||||
"num_training_examples": num_instances,
|
"num_training_examples": num_instances,
|
||||||
"max_seq_len": args.max_seq_len
|
"max_seq_len": args.max_seq_len
|
||||||
}
|
}
|
||||||
metrics_file.write(json.dumps(metrics))
|
metrics_file.write(json.dumps(metrics))
|
||||||
docs.cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user