updating examples

This commit is contained in:
thomwolf
2018-11-17 10:30:54 +01:00
parent d0673c7dbd
commit 4e46affc34
4 changed files with 19 additions and 70 deletions

View File

@@ -699,11 +699,9 @@ def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
help="The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints will be written.")
@@ -711,11 +709,6 @@ def main():
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
parser.add_argument("--predict_file", default=None, type=str,
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
parser.add_argument("--init_checkpoint", default=None, type=str,
help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--do_lower_case", default=True, action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--max_seq_length", default=384, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
@@ -815,20 +808,11 @@ def main():
raise ValueError(
"If `do_predict` is True, then `predict_file` must be specified.")
bert_config = BertConfig.from_json_file(args.bert_config_file)
if args.max_seq_length > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(args.max_seq_length, bert_config.max_position_embeddings))
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory () already exists and is not empty.")
os.makedirs(args.output_dir, exist_ok=True)
tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
train_examples = None
num_train_steps = None
@@ -839,9 +823,7 @@ def main():
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model
model = BertForQuestionAnswering(bert_config)
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
if args.fp16:
model.half()
model.to(device)