model training loop working – still have to check that everything is exactly same

This commit is contained in:
thomwolf
2018-11-02 01:31:31 +01:00
parent f690f0e167
commit 9343a2311b
2 changed files with 37 additions and 34 deletions

View File

@@ -484,7 +484,7 @@ def main():
num_train_steps = int(
len(train_examples) / args.train_batch_size * args.num_train_epochs)
model = BertForSequenceClassification(bert_config)
model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)
@@ -504,10 +504,10 @@ def main():
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
@@ -519,12 +519,12 @@ def main():
model.train()
global_step = 0
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids.to(device)
input_mask.to(device)
segment_ids.to(device)
label_ids.to(device)
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
loss = model(input_ids, segment_ids, input_mask, label_ids)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss.backward()
optimizer.step()
global_step += 1
@@ -538,10 +538,10 @@ def main():
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
@@ -554,10 +554,10 @@ def main():
eval_loss = 0
eval_accuracy = 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
input_ids.to(device)
input_mask.to(device)
segment_ids.to(device)
label_ids.to(device)
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
tmp_eval_accuracy = accuracy(logits, label_ids)