updating examples
This commit is contained in:
@@ -193,23 +193,16 @@ def main():
|
|||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--input_file", default=None, type=str, required=True)
|
parser.add_argument("--input_file", default=None, type=str, required=True)
|
||||||
parser.add_argument("--vocab_file", default=None, type=str, required=True,
|
|
||||||
help="The vocabulary file that the BERT model was trained on.")
|
|
||||||
parser.add_argument("--output_file", default=None, type=str, required=True)
|
parser.add_argument("--output_file", default=None, type=str, required=True)
|
||||||
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
|
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||||
help="The config json file corresponding to the pre-trained BERT model. "
|
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||||
"This specifies the model architecture.")
|
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||||
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
|
|
||||||
help="Initial checkpoint (usually from a pre-trained BERT model).")
|
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
|
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
|
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
|
||||||
"than this will be truncated, and sequences shorter than this will be padded.")
|
"than this will be truncated, and sequences shorter than this will be padded.")
|
||||||
parser.add_argument("--do_lower_case", default=True, action='store_true',
|
|
||||||
help="Whether to lower case the input text. Should be True for uncased "
|
|
||||||
"models and False for cased models.")
|
|
||||||
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
|
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
|
||||||
parser.add_argument("--local_rank",
|
parser.add_argument("--local_rank",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -230,10 +223,7 @@ def main():
|
|||||||
|
|
||||||
layer_indexes = [int(x) for x in args.layers.split(",")]
|
layer_indexes = [int(x) for x in args.layers.split(",")]
|
||||||
|
|
||||||
bert_config = BertConfig.from_json_file(args.bert_config_file)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
|
||||||
|
|
||||||
tokenizer = BertTokenizer(
|
|
||||||
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
|
||||||
|
|
||||||
examples = read_examples(args.input_file)
|
examples = read_examples(args.input_file)
|
||||||
|
|
||||||
@@ -244,9 +234,7 @@ def main():
|
|||||||
for feature in features:
|
for feature in features:
|
||||||
unique_id_to_feature[feature.unique_id] = feature
|
unique_id_to_feature[feature.unique_id] = feature
|
||||||
|
|
||||||
model = BertModel(bert_config)
|
model = BertModel.from_pretrained(args.bert_model)
|
||||||
if args.init_checkpoint is not None:
|
|
||||||
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
|
|||||||
@@ -343,12 +343,9 @@ def main():
|
|||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||||
parser.add_argument("--bert_config_file",
|
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||||
default=None,
|
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||||
type=str,
|
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||||
required=True,
|
|
||||||
help="The config json file corresponding to the pre-trained BERT model. \n"
|
|
||||||
"This specifies the model architecture.")
|
|
||||||
parser.add_argument("--task_name",
|
parser.add_argument("--task_name",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
@@ -366,14 +363,6 @@ def main():
|
|||||||
help="The output directory where the model checkpoints will be written.")
|
help="The output directory where the model checkpoints will be written.")
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--init_checkpoint",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help="Initial checkpoint (usually from a pre-trained BERT model).")
|
|
||||||
parser.add_argument("--do_lower_case",
|
|
||||||
default=False,
|
|
||||||
action='store_true',
|
|
||||||
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
|
||||||
parser.add_argument("--max_seq_length",
|
parser.add_argument("--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
@@ -477,13 +466,6 @@ def main():
|
|||||||
if not args.do_train and not args.do_eval:
|
if not args.do_train and not args.do_eval:
|
||||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||||
|
|
||||||
bert_config = BertConfig.from_json_file(args.bert_config_file)
|
|
||||||
|
|
||||||
if args.max_seq_length > bert_config.max_position_embeddings:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
|
|
||||||
args.max_seq_length, bert_config.max_position_embeddings))
|
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
@@ -496,8 +478,7 @@ def main():
|
|||||||
processor = processors[task_name]()
|
processor = processors[task_name]()
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
|
|
||||||
tokenizer = BertTokenizer(
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
|
||||||
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
|
||||||
|
|
||||||
train_examples = None
|
train_examples = None
|
||||||
num_train_steps = None
|
num_train_steps = None
|
||||||
@@ -507,9 +488,7 @@ def main():
|
|||||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForSequenceClassification(bert_config, len(label_list))
|
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
|
||||||
if args.init_checkpoint is not None:
|
|
||||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|||||||
@@ -699,11 +699,9 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
|
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||||
help="The config json file corresponding to the pre-trained BERT model. "
|
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||||
"This specifies the model architecture.")
|
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||||
parser.add_argument("--vocab_file", default=None, type=str, required=True,
|
|
||||||
help="The vocabulary file that the BERT model was trained on.")
|
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||||
help="The output directory where the model checkpoints will be written.")
|
help="The output directory where the model checkpoints will be written.")
|
||||||
|
|
||||||
@@ -711,11 +709,6 @@ def main():
|
|||||||
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
|
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
|
||||||
parser.add_argument("--predict_file", default=None, type=str,
|
parser.add_argument("--predict_file", default=None, type=str,
|
||||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||||
parser.add_argument("--init_checkpoint", default=None, type=str,
|
|
||||||
help="Initial checkpoint (usually from a pre-trained BERT model).")
|
|
||||||
parser.add_argument("--do_lower_case", default=True, action='store_true',
|
|
||||||
help="Whether to lower case the input text. Should be True for uncased "
|
|
||||||
"models and False for cased models.")
|
|
||||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
parser.add_argument("--max_seq_length", default=384, type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
||||||
@@ -815,20 +808,11 @@ def main():
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `do_predict` is True, then `predict_file` must be specified.")
|
"If `do_predict` is True, then `predict_file` must be specified.")
|
||||||
|
|
||||||
bert_config = BertConfig.from_json_file(args.bert_config_file)
|
|
||||||
|
|
||||||
if args.max_seq_length > bert_config.max_position_embeddings:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot use sequence length %d because the BERT model "
|
|
||||||
"was only trained up to sequence length %d" %
|
|
||||||
(args.max_seq_length, bert_config.max_position_embeddings))
|
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ValueError("Output directory () already exists and is not empty.")
|
raise ValueError("Output directory () already exists and is not empty.")
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
tokenizer = BertTokenizer(
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
|
||||||
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
|
||||||
|
|
||||||
train_examples = None
|
train_examples = None
|
||||||
num_train_steps = None
|
num_train_steps = None
|
||||||
@@ -839,9 +823,7 @@ def main():
|
|||||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForQuestionAnswering(bert_config)
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||||
if args.init_checkpoint is not None:
|
|
||||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|||||||
8
setup.py
8
setup.py
@@ -13,11 +13,11 @@ setup(
|
|||||||
url="https://github.com/huggingface/pytorch-pretrained-BERT",
|
url="https://github.com/huggingface/pytorch-pretrained-BERT",
|
||||||
packages=find_packages(exclude=["*.tests", "*.tests.*",
|
packages=find_packages(exclude=["*.tests", "*.tests.*",
|
||||||
"tests.*", "tests"]),
|
"tests.*", "tests"]),
|
||||||
install_requires=['numpy',
|
install_requires=['torch>=0.4.1',
|
||||||
'torch>=0.4.1',
|
'numpy',
|
||||||
'boto3',
|
'boto3',
|
||||||
'requests>=2.18',
|
'requests',
|
||||||
'tqdm>=4.19'],
|
'tqdm'],
|
||||||
scripts=["bin/pytorch_pretrained_bert"],
|
scripts=["bin/pytorch_pretrained_bert"],
|
||||||
python_requires='>=3.5.0',
|
python_requires='>=3.5.0',
|
||||||
tests_require=['pytest'],
|
tests_require=['pytest'],
|
||||||
|
|||||||
Reference in New Issue
Block a user