small tweaks

This commit is contained in:
thomwolf
2018-11-02 01:38:22 +01:00
parent 9343a2311b
commit 2c731fd129
2 changed files with 32 additions and 25 deletions

View File

@@ -115,16 +115,10 @@ parser.add_argument("--save_checkpoints_steps",
default = 1000,
type = int,
help = "How often to save the model checkpoint.")
parser.add_argument("--iterations_per_loop",
default = 1000,
type = int,
help = "How many steps to make in each estimator call.")
parser.add_argument("--no_cuda",
default = False,
type = bool,
help = "Whether not to use CUDA when available")
parser.add_argument("--local_rank",
type=int,
default=-1,
@@ -518,16 +512,17 @@ def main():
model.train()
global_step = 0
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
for epoch in args.num_train_epochs:
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss.backward()
optimizer.step()
global_step += 1
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss.backward()
optimizer.step()
global_step += 1
if args.do_eval:
eval_examples = processor.get_dev_examples(args.data_dir)