Reduced memory usage for pregenerating the data a lot by writing it
out on the fly without shuffling - the Sampler in the finetuning script will shuffle for us.
This commit is contained in:
@@ -73,7 +73,10 @@ class PregeneratedDataset(Dataset):
|
|||||||
logging.info(f"Loading training examples for epoch {epoch}")
|
logging.info(f"Loading training examples for epoch {epoch}")
|
||||||
with data_file.open() as f:
|
with data_file.open() as f:
|
||||||
for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
|
for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
|
||||||
example = json.loads(line.rstrip())
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue # Skip trailing blank lines etc.
|
||||||
|
example = json.loads(line)
|
||||||
features = convert_example_to_features(example, tokenizer, seq_len)
|
features = convert_example_to_features(example, tokenizer, seq_len)
|
||||||
input_ids[i] = features.input_ids
|
input_ids[i] = features.input_ids
|
||||||
segment_ids[i] = features.segment_ids
|
segment_ids[i] = features.segment_ids
|
||||||
|
|||||||
@@ -242,24 +242,22 @@ def main():
|
|||||||
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
|
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
|
||||||
# Google BERT doesn't do this, and as a result oversamples shorter docs
|
# Google BERT doesn't do this, and as a result oversamples shorter docs
|
||||||
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
|
||||||
epoch_instances = []
|
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
|
||||||
|
num_instances = 0
|
||||||
|
with epoch_filename.open('w') as epoch_file:
|
||||||
for doc_idx in trange(len(docs), desc="Document"):
|
for doc_idx in trange(len(docs), desc="Document"):
|
||||||
doc_instances = create_instances_from_document(
|
doc_instances = create_instances_from_document(
|
||||||
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
|
||||||
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
|
||||||
vocab_list=vocab_list)
|
vocab_list=vocab_list)
|
||||||
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
doc_instances = [json.dumps(instance) for instance in doc_instances]
|
||||||
epoch_instances.extend(doc_instances)
|
for instance in doc_instances:
|
||||||
|
epoch_file.write(instance + '\n')
|
||||||
shuffle(epoch_instances)
|
num_instances += 1
|
||||||
epoch_file = args.output_dir / f"epoch_{epoch}.json"
|
|
||||||
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
|
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
|
||||||
with epoch_file.open('w') as out_file:
|
|
||||||
for instance in epoch_instances:
|
|
||||||
out_file.write(instance + '\n')
|
|
||||||
with metrics_file.open('w') as metrics_file:
|
with metrics_file.open('w') as metrics_file:
|
||||||
metrics = {
|
metrics = {
|
||||||
"num_training_examples": len(epoch_instances),
|
"num_training_examples": num_instances,
|
||||||
"max_seq_len": args.max_seq_len
|
"max_seq_len": args.max_seq_len
|
||||||
}
|
}
|
||||||
metrics_file.write(json.dumps(metrics))
|
metrics_file.write(json.dumps(metrics))
|
||||||
|
|||||||
Reference in New Issue
Block a user