Fix tr_loss rescaling factor using global_step

This commit is contained in:
Mathieu Prouveur
2019-04-29 12:58:29 +02:00
parent ed8fad7390
commit 87b9ec3843
2 changed files with 5 additions and 5 deletions

View File

@@ -452,7 +452,7 @@ def main():
loss = loss * args.loss_scale
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
tr_loss += loss.item() * args.gradient_accumulation_steps
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
@@ -537,7 +537,7 @@ def main():
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'global_step': global_step,
'loss': tr_loss/nb_tr_steps}
'loss': tr_loss/global_step}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer: