Fix small bug in run_squad_pytorch.py

This commit is contained in:
VictorSanh
2018-11-02 03:32:35 -04:00
parent 98b9771dfe
commit 62ac7e9a60

View File

@@ -27,6 +27,7 @@ import tokenization
import six import six
import argparse import argparse
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
@@ -103,6 +104,10 @@ parser.add_argument("--max_answer_length", default=30, type=int,
parser.add_argument("--verbose_logging", default=False, type=bool, parser.add_argument("--verbose_logging", default=False, type=bool,
help="If true, all of the warnings related to data processing will be printed. " help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.") "A number of warnings are expected for a normal SQuAD evaluation.")
# Opt-out switch: the flag is False unless `--no_cuda` is passed on the command line.
parser.add_argument(
    "--no_cuda",
    action="store_true",
    default=False,
    help="Whether not to use CUDA when available",
)
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
default=-1, default=-1,
@@ -769,8 +774,7 @@ def main():
(args.max_seq_length, bert_config.max_position_embeddings)) (args.max_seq_length, bert_config.max_position_embeddings))
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError(f"Output directory ({args.output_dir}) already exists and is " raise ValueError("Output directory () already exists and is not empty.")
f"not empty.")
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
tokenizer = tokenization.FullTokenizer( tokenizer = tokenization.FullTokenizer(
@@ -795,7 +799,8 @@ def main():
lr=args.learning_rate, schedule='warmup_linear', lr=args.learning_rate, schedule='warmup_linear',
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=num_train_steps) t_total=num_train_steps)
global_step = 0
if args.do_train: if args.do_train:
train_features = convert_examples_to_features( train_features = convert_examples_to_features(
examples=train_examples, examples=train_examples,
@@ -823,7 +828,7 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train() model.train()
for epoch in args.num_train_epochs: for epoch in range(int(args.num_train_epochs)):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)