Update example files so that tr_loss is not affected by args.gradient_accumulation_step
This commit is contained in:
@@ -845,7 +845,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item() * args.gradient_accumulation_steps
|
||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
nb_tr_steps += 1
|
nb_tr_steps += 1
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
|
|||||||
@@ -452,7 +452,7 @@ def main():
|
|||||||
loss = loss * args.loss_scale
|
loss = loss * args.loss_scale
|
||||||
if args.gradient_accumulation_steps > 1:
|
if args.gradient_accumulation_steps > 1:
|
||||||
loss = loss / args.gradient_accumulation_steps
|
loss = loss / args.gradient_accumulation_steps
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item() * args.gradient_accumulation_steps
|
||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
nb_tr_steps += 1
|
nb_tr_steps += 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user