Updating examples: load the BERT tokenizer and model with from_pretrained(--bert_model)

This commit is contained in:
thomwolf
2018-11-17 10:30:54 +01:00
parent d0673c7dbd
commit 4e46affc34
4 changed files with 19 additions and 70 deletions

View File

@@ -193,23 +193,16 @@ def main():
## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
help="The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
## Other parameters
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--do_lower_case", default=True, action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
@@ -230,10 +223,7 @@ def main():
layer_indexes = [int(x) for x in args.layers.split(",")]
bert_config = BertConfig.from_json_file(args.bert_config_file)
tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
examples = read_examples(args.input_file)
@@ -244,9 +234,7 @@ def main():
for feature in features:
unique_id_to_feature[feature.unique_id] = feature
model = BertModel(bert_config)
if args.init_checkpoint is not None:
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model = BertModel.from_pretrained(args.bert_model)
model.to(device)
if args.local_rank != -1: