Corrected the displayed loss when gradient_accumulation_steps > 1
This commit is contained in:
@@ -309,7 +309,7 @@ def main():
|
|||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
nb_tr_steps += 1
|
nb_tr_steps += 1
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
mean_loss = tr_loss / nb_tr_steps
|
mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
|
||||||
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
|
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
|
|||||||
Reference in New Issue
Block a user