Syncing up argument names between the scripts
This commit is contained in:
@@ -201,8 +201,8 @@ def create_instances_from_document(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument('--corpus_path', type=Path, required=True)
|
parser.add_argument('--train_corpus', type=Path, required=True)
|
||||||
parser.add_argument("--save_dir", type=Path, required=True)
|
parser.add_argument("--output_dir", type=Path, required=True)
|
||||||
parser.add_argument("--bert_model", type=str, required=True,
|
parser.add_argument("--bert_model", type=str, required=True,
|
||||||
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
|
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
|
||||||
"bert-base-multilingual", "bert-base-chinese"])
|
"bert-base-multilingual", "bert-base-chinese"])
|
||||||
@@ -229,7 +229,7 @@ def main():
|
|||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
vocab_list = list(tokenizer.vocab.keys())
|
vocab_list = list(tokenizer.vocab.keys())
|
||||||
with args.corpus_path.open() as f:
|
with args.train_corpus.open() as f:
|
||||||
docs = []
|
docs = []
|
||||||
doc = []
|
doc = []
|
||||||
for line in tqdm(f, desc="Loading Dataset"):
|
for line in tqdm(f, desc="Loading Dataset"):
|
||||||
@@ -241,7 +241,7 @@ def main():
|
|||||||
tokens = tokenizer.tokenize(line)
|
tokens = tokenizer.tokenize(line)
|
||||||
doc.append(tokens)
|
doc.append(tokens)
|
||||||
|
|
||||||
args.save_dir.mkdir(exist_ok=True)
|
args.output_dir.mkdir(exist_ok=True)
|
||||||
docs = DocumentDatabase(docs)
|
docs = DocumentDatabase(docs)
|
||||||
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
|
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
|
||||||
# Google BERT doesn't do this, and as a result oversamples shorter docs
|
# Google BERT doesn't do this, and as a result oversamples shorter docs
|
||||||
@@ -256,8 +256,8 @@ def main():
|
|||||||
epoch_instances.extend(doc_instances)
|
epoch_instances.extend(doc_instances)
|
||||||
|
|
||||||
shuffle(epoch_instances)
|
shuffle(epoch_instances)
|
||||||
epoch_file = args.save_dir / f"epoch_{epoch}.json"
|
epoch_file = args.output_dir / f"epoch_{epoch}.json"
|
||||||
metrics_file = args.save_dir / f"epoch_{epoch}_metrics.json"
|
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
|
||||||
with epoch_file.open('w') as out_file:
|
with epoch_file.open('w') as out_file:
|
||||||
for instance in epoch_instances:
|
for instance in epoch_instances:
|
||||||
out_file.write(instance + '\n')
|
out_file.write(instance + '\n')
|
||||||
|
|||||||
@@ -401,7 +401,7 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--train_file",
|
parser.add_argument("--train_corpus",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
@@ -511,8 +511,8 @@ def main():
|
|||||||
#train_examples = None
|
#train_examples = None
|
||||||
num_train_optimization_steps = None
|
num_train_optimization_steps = None
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
print("Loading Train Dataset", args.train_file)
|
print("Loading Train Dataset", args.train_corpus)
|
||||||
train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
|
train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
|
||||||
corpus_lines=None, on_memory=args.on_memory)
|
corpus_lines=None, on_memory=args.on_memory)
|
||||||
num_train_optimization_steps = int(
|
num_train_optimization_steps = int(
|
||||||
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||||
|
|||||||
Reference in New Issue
Block a user