Fix small bug in run_squad_pytorch.py
This commit is contained in:
@@ -27,6 +27,7 @@ import tokenization
|
||||
import six
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
@@ -103,6 +104,10 @@ parser.add_argument("--max_answer_length", default=30, type=int,
|
||||
parser.add_argument("--verbose_logging", default=False, type=bool,
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||
parser.add_argument("--no_cuda",
|
||||
default = False,
|
||||
action='store_true',
|
||||
help = "Whether not to use CUDA when available")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
@@ -769,8 +774,7 @@ def main():
|
||||
(args.max_seq_length, bert_config.max_position_embeddings))
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
|
||||
f"not empty.")
|
||||
raise ValueError("Output directory () already exists and is not empty.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
@@ -795,7 +799,8 @@ def main():
|
||||
lr=args.learning_rate, schedule='warmup_linear',
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_steps)
|
||||
|
||||
|
||||
global_step = 0
|
||||
if args.do_train:
|
||||
train_features = convert_examples_to_features(
|
||||
examples=train_examples,
|
||||
@@ -823,7 +828,7 @@ def main():
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for epoch in args.num_train_epochs:
|
||||
for epoch in range(int(args.num_train_epochs)):
|
||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
|
||||
Reference in New Issue
Block a user