add best steps to train
This commit is contained in:
@@ -127,6 +127,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
global_step = 0
|
global_step = 0
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
best_dev_acc, best_dev_loss = 0.0, 99999999999.0
|
best_dev_acc, best_dev_loss = 0.0, 99999999999.0
|
||||||
|
best_steps = 0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||||
@@ -171,6 +172,8 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
if results["eval_loss"] < best_dev_loss:
|
if results["eval_loss"] < best_dev_loss:
|
||||||
best_dev_acc = results["eval_acc"]
|
best_dev_acc = results["eval_acc"]
|
||||||
best_dev_loss = results["eval_loss"]
|
best_dev_loss = results["eval_loss"]
|
||||||
|
best_steps = global_step
|
||||||
|
if args.do_test:
|
||||||
results_test = evaluate(args, model, tokenizer, test=True)
|
results_test = evaluate(args, model, tokenizer, test=True)
|
||||||
for key, value in results_test.items():
|
for key, value in results_test.items():
|
||||||
tb_writer.add_scalar('test_{}'.format(key), value, global_step)
|
tb_writer.add_scalar('test_{}'.format(key), value, global_step)
|
||||||
@@ -201,7 +204,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
tb_writer.close()
|
tb_writer.close()
|
||||||
|
|
||||||
return global_step, tr_loss / global_step
|
return global_step, tr_loss / global_step, best_steps
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, prefix="", test=False):
|
def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||||
@@ -471,7 +474,7 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss, _ = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user