From abb7d1ff6dc7e54bd029eb8703fb12caf24baf38 Mon Sep 17 00:00:00 2001 From: Matthew Carrigan Date: Thu, 21 Mar 2019 17:50:03 +0000 Subject: [PATCH] Added proper context management to ensure cleanup happens in the right order. --- .../pregenerate_training_data.py | 73 ++++++++++--------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/examples/lm_finetuning/pregenerate_training_data.py b/examples/lm_finetuning/pregenerate_training_data.py index 678dfbc0bc..498ab22333 100644 --- a/examples/lm_finetuning/pregenerate_training_data.py +++ b/examples/lm_finetuning/pregenerate_training_data.py @@ -23,6 +23,7 @@ class DocumentDatabase: self.documents = [] self.document_shelf = None self.document_shelf_filepath = None + self.temp_dir = None self.doc_lengths = [] self.doc_cumsum = None self.cumsum_max = None @@ -68,9 +69,14 @@ class DocumentDatabase: else: return self.documents[item] - def cleanup(self): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): if self.document_shelf is not None: self.document_shelf.close() + if self.temp_dir is not None: + self.temp_dir.cleanup() def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): @@ -247,40 +253,39 @@ def main(): tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) vocab_list = list(tokenizer.vocab.keys()) - docs = DocumentDatabase(reduce_memory=args.reduce_memory) - with args.train_corpus.open() as f: - doc = [] - for line in tqdm(f, desc="Loading Dataset", unit=" lines"): - line = line.strip() - if line == "": - docs.add_document(doc) - doc = [] - else: - tokens = tokenizer.tokenize(line) - doc.append(tokens) + with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: + with args.train_corpus.open() as f: + doc = [] + for line in tqdm(f, desc="Loading Dataset", unit=" lines"): + line = line.strip() + if line == "": + docs.add_document(doc) + doc = [] + else: + tokens = tokenizer.tokenize(line) + doc.append(tokens) - args.output_dir.mkdir(exist_ok=True) - for epoch in trange(args.epochs_to_generate, desc="Epoch"): - epoch_filename = args.output_dir / f"epoch_{epoch}.json" - num_instances = 0 - with epoch_filename.open('w') as epoch_file: - for doc_idx in trange(len(docs), desc="Document"): - doc_instances = create_instances_from_document( - docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, - masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, - vocab_list=vocab_list) - doc_instances = [json.dumps(instance) for instance in doc_instances] - for instance in doc_instances: - epoch_file.write(instance + '\n') - num_instances += 1 - metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json" - with metrics_file.open('w') as metrics_file: - metrics = { - "num_training_examples": num_instances, - "max_seq_len": args.max_seq_len - } - metrics_file.write(json.dumps(metrics)) - docs.cleanup() + args.output_dir.mkdir(exist_ok=True) + for epoch in trange(args.epochs_to_generate, desc="Epoch"): + epoch_filename = args.output_dir / f"epoch_{epoch}.json" + num_instances = 0 + with epoch_filename.open('w') as epoch_file: + for doc_idx in trange(len(docs), desc="Document"): + doc_instances = create_instances_from_document( + docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, + masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, + vocab_list=vocab_list) + doc_instances = [json.dumps(instance) for instance in doc_instances] + for instance in doc_instances: + epoch_file.write(instance + '\n') + num_instances += 1 + metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json" + with metrics_file.open('w') as metrics_file: + metrics = { + "num_training_examples": num_instances, + "max_seq_len": args.max_seq_len + } + metrics_file.write(json.dumps(metrics)) if __name__ == '__main__':