Fix small bug in run_squad_pytorch.py
This commit is contained in:
@@ -27,6 +27,7 @@ import tokenization
|
|||||||
import six
|
import six
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
@@ -103,6 +104,10 @@ parser.add_argument("--max_answer_length", default=30, type=int,
|
|||||||
parser.add_argument("--verbose_logging", default=False, type=bool,
|
parser.add_argument("--verbose_logging", default=False, type=bool,
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help="If true, all of the warnings related to data processing will be printed. "
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||||
|
parser.add_argument("--no_cuda",
|
||||||
|
default = False,
|
||||||
|
action='store_true',
|
||||||
|
help = "Whether not to use CUDA when available")
|
||||||
parser.add_argument("--local_rank",
|
parser.add_argument("--local_rank",
|
||||||
type=int,
|
type=int,
|
||||||
default=-1,
|
default=-1,
|
||||||
@@ -769,8 +774,7 @@ def main():
|
|||||||
(args.max_seq_length, bert_config.max_position_embeddings))
|
(args.max_seq_length, bert_config.max_position_embeddings))
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
|
raise ValueError("Output directory () already exists and is not empty.")
|
||||||
f"not empty.")
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
tokenizer = tokenization.FullTokenizer(
|
tokenizer = tokenization.FullTokenizer(
|
||||||
@@ -796,6 +800,7 @@ def main():
|
|||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_steps)
|
t_total=num_train_steps)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_features = convert_examples_to_features(
|
train_features = convert_examples_to_features(
|
||||||
examples=train_examples,
|
examples=train_examples,
|
||||||
@@ -823,7 +828,7 @@ def main():
|
|||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
for epoch in args.num_train_epochs:
|
for epoch in range(int(args.num_train_epochs)):
|
||||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user