diff --git a/examples/extract_features.py b/examples/extract_features.py index 3ea1909bb3..fce90dffa2 100644 --- a/examples/extract_features.py +++ b/examples/extract_features.py @@ -193,23 +193,16 @@ def main(): ## Required parameters parser.add_argument("--input_file", default=None, type=str, required=True) - parser.add_argument("--vocab_file", default=None, type=str, required=True, - help="The vocabulary file that the BERT model was trained on.") parser.add_argument("--output_file", default=None, type=str, required=True) - parser.add_argument("--bert_config_file", default=None, type=str, required=True, - help="The config json file corresponding to the pre-trained BERT model. " - "This specifies the model architecture.") - parser.add_argument("--init_checkpoint", default=None, type=str, required=True, - help="Initial checkpoint (usually from a pre-trained BERT model).") + parser.add_argument("--bert_model", default=None, type=str, required=True, + help="Bert pre-trained model selected in the list: bert-base-uncased, " + "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") ## Other parameters parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. Sequences longer " "than this will be truncated, and sequences shorter than this will be padded.") - parser.add_argument("--do_lower_case", default=True, action='store_true', - help="Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.") parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") parser.add_argument("--local_rank", type=int, @@ -230,10 +223,7 @@ def main(): layer_indexes = [int(x) for x in args.layers.split(",")] - bert_config = BertConfig.from_json_file(args.bert_config_file) - - tokenizer = BertTokenizer( - vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) + tokenizer = BertTokenizer.from_pretrained(args.bert_model) examples = read_examples(args.input_file) @@ -244,9 +234,7 @@ def main(): for feature in features: unique_id_to_feature[feature.unique_id] = feature - model = BertModel(bert_config) - if args.init_checkpoint is not None: - model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) + model = BertModel.from_pretrained(args.bert_model) model.to(device) if args.local_rank != -1: diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 9543b95a94..7ad2492b90 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -343,12 +343,9 @@ def main(): type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.") - parser.add_argument("--bert_config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained BERT model. \n" - "This specifies the model architecture.") + parser.add_argument("--bert_model", default=None, type=str, required=True, + help="Bert pre-trained model selected in the list: bert-base-uncased, " + "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") parser.add_argument("--task_name", default=None, type=str, @@ -366,14 +363,6 @@ def main(): help="The output directory where the model checkpoints will be written.") ## Other parameters - parser.add_argument("--init_checkpoint", - default=None, - type=str, - help="Initial checkpoint (usually from a pre-trained BERT model).") - parser.add_argument("--do_lower_case", - default=False, - action='store_true', - help="Whether to lower case the input text. True for uncased models, False for cased models.") parser.add_argument("--max_seq_length", default=128, type=int, @@ -477,13 +466,6 @@ def main(): if not args.do_train and not args.do_eval: raise ValueError("At least one of `do_train` or `do_eval` must be True.") - bert_config = BertConfig.from_json_file(args.bert_config_file) - - if args.max_seq_length > bert_config.max_position_embeddings: - raise ValueError( - "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format( - args.max_seq_length, bert_config.max_position_embeddings)) - if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) os.makedirs(args.output_dir, exist_ok=True) @@ -496,8 +478,7 @@ def main(): processor = processors[task_name]() label_list = processor.get_labels() - tokenizer = BertTokenizer( - vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) + tokenizer = BertTokenizer.from_pretrained(args.bert_model) train_examples = None num_train_steps = None @@ -507,9 +488,7 @@ def main(): len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) # Prepare model - model = BertForSequenceClassification(bert_config, len(label_list)) - if args.init_checkpoint is not None: - model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) + model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list)) if args.fp16: model.half() model.to(device) diff --git a/examples/run_squad.py b/examples/run_squad.py index 6f06fb5c64..c9acbbac7e 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -699,11 +699,9 @@ def main(): parser = argparse.ArgumentParser() ## Required parameters - parser.add_argument("--bert_config_file", default=None, type=str, required=True, - help="The config json file corresponding to the pre-trained BERT model. " - "This specifies the model architecture.") - parser.add_argument("--vocab_file", default=None, type=str, required=True, - help="The vocabulary file that the BERT model was trained on.") + parser.add_argument("--bert_model", default=None, type=str, required=True, + help="Bert pre-trained model selected in the list: bert-base-uncased, " + "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model checkpoints will be written.") @@ -711,11 +709,6 @@ def main(): parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json") parser.add_argument("--predict_file", default=None, type=str, help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") - parser.add_argument("--init_checkpoint", default=None, type=str, - help="Initial checkpoint (usually from a pre-trained BERT model).") - parser.add_argument("--do_lower_case", default=True, action='store_true', - help="Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.") parser.add_argument("--max_seq_length", default=384, type=int, help="The maximum total input sequence length after WordPiece tokenization. Sequences " "longer than this will be truncated, and sequences shorter than this will be padded.") @@ -815,20 +808,11 @@ def main(): raise ValueError( "If `do_predict` is True, then `predict_file` must be specified.") - bert_config = BertConfig.from_json_file(args.bert_config_file) - - if args.max_seq_length > bert_config.max_position_embeddings: - raise ValueError( - "Cannot use sequence length %d because the BERT model " - "was only trained up to sequence length %d" % - (args.max_seq_length, bert_config.max_position_embeddings)) - if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory () already exists and is not empty.") os.makedirs(args.output_dir, exist_ok=True) - tokenizer = BertTokenizer( - vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) + tokenizer = BertTokenizer.from_pretrained(args.bert_model) train_examples = None num_train_steps = None @@ -839,9 +823,7 @@ def main(): len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) # Prepare model - model = BertForQuestionAnswering(bert_config) - if args.init_checkpoint is not None: - model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) + model = BertForQuestionAnswering.from_pretrained(args.bert_model) if args.fp16: model.half() model.to(device) diff --git a/setup.py b/setup.py index 95910be89f..7e28c68153 100644 --- a/setup.py +++ b/setup.py @@ -13,11 +13,11 @@ setup( url="https://github.com/huggingface/pytorch-pretrained-BERT", packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), - install_requires=['numpy', - 'torch>=0.4.1', + install_requires=['torch>=0.4.1', + 'numpy', 'boto3', - 'requests>=2.18', - 'tqdm>=4.19'], + 'requests', + 'tqdm'], scripts=["bin/pytorch_pretrained_bert"], python_requires='>=3.5.0', tests_require=['pytest'],