updating examples
This commit is contained in:
@@ -343,12 +343,9 @@ def main():
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||
parser.add_argument("--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
@@ -366,14 +363,6 @@ def main():
|
||||
help="The output directory where the model checkpoints will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--init_checkpoint",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Initial checkpoint (usually from a pre-trained BERT model).")
|
||||
parser.add_argument("--do_lower_case",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||
parser.add_argument("--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
@@ -477,13 +466,6 @@ def main():
|
||||
if not args.do_train and not args.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
|
||||
bert_config = BertConfig.from_json_file(args.bert_config_file)
|
||||
|
||||
if args.max_seq_length > bert_config.max_position_embeddings:
|
||||
raise ValueError(
|
||||
"Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
|
||||
args.max_seq_length, bert_config.max_position_embeddings))
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
@@ -496,8 +478,7 @@ def main():
|
||||
processor = processors[task_name]()
|
||||
label_list = processor.get_labels()
|
||||
|
||||
tokenizer = BertTokenizer(
|
||||
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
|
||||
|
||||
train_examples = None
|
||||
num_train_steps = None
|
||||
@@ -507,9 +488,7 @@ def main():
|
||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
||||
|
||||
# Prepare model
|
||||
model = BertForSequenceClassification(bert_config, len(label_list))
|
||||
if args.init_checkpoint is not None:
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
||||
Reference in New Issue
Block a user