Fix tot update in trainer (#37923)
* fix total updates in epoch * add test; fix max_steps * replace with multi-gpu decorator
This commit is contained in:
@@ -2495,13 +2495,13 @@ class Trainer:
|
|||||||
step = -1
|
step = -1
|
||||||
epoch_iterator = iter(epoch_dataloader)
|
epoch_iterator = iter(epoch_dataloader)
|
||||||
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
|
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
|
||||||
remainder = num_examples % args.gradient_accumulation_steps
|
remainder = steps_in_epoch % args.gradient_accumulation_steps
|
||||||
if remainder == 0:
|
if remainder == 0:
|
||||||
remainder = args.gradient_accumulation_steps
|
remainder = args.gradient_accumulation_steps
|
||||||
update_step = -1
|
update_step = -1
|
||||||
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
|
total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(
|
||||||
if args.gradient_accumulation_steps == 1:
|
remainder < args.gradient_accumulation_steps
|
||||||
total_updates -= 1
|
)
|
||||||
for _ in range(total_updates):
|
for _ in range(total_updates):
|
||||||
update_step += 1
|
update_step += 1
|
||||||
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
||||||
@@ -5319,7 +5319,11 @@ class Trainer:
|
|||||||
|
|
||||||
# Case 2: We have a dataloader length and can extrapolate
|
# Case 2: We have a dataloader length and can extrapolate
|
||||||
if len_dataloader is not None:
|
if len_dataloader is not None:
|
||||||
num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
|
num_update_steps_per_epoch = max(
|
||||||
|
len_dataloader // args.gradient_accumulation_steps
|
||||||
|
+ int(len_dataloader % args.gradient_accumulation_steps > 0),
|
||||||
|
1,
|
||||||
|
)
|
||||||
# Case 3: We have a length but are using epochs, we can extrapolate the number of steps
|
# Case 3: We have a length but are using epochs, we can extrapolate the number of steps
|
||||||
if epoch_based:
|
if epoch_based:
|
||||||
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch_fp16,
|
require_torch_fp16,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
|
require_torch_multi_gpu,
|
||||||
require_torch_non_multi_accelerator,
|
require_torch_non_multi_accelerator,
|
||||||
require_torch_non_multi_gpu,
|
require_torch_non_multi_gpu,
|
||||||
require_torch_tensorrt_fx,
|
require_torch_tensorrt_fx,
|
||||||
@@ -3763,6 +3764,37 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
train_output = trainer.train()
|
train_output = trainer.train()
|
||||||
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_num_batches_in_training_with_gradient_accumulation(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
for num_train_epochs in [1, 2]:
|
||||||
|
for train_len in [123, 120]:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
train_len=train_len,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
gradient_accumulation_steps=5,
|
||||||
|
num_train_epochs=num_train_epochs,
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_batch_samples = []
|
||||||
|
|
||||||
|
def wrap_get_batch_samples(fn):
|
||||||
|
def wrapped_fn(epoch_iterator, num_batches, device):
|
||||||
|
self.assertGreater(num_batches, 0)
|
||||||
|
batch_samples, num_items_in_batch = fn(epoch_iterator, num_batches, device)
|
||||||
|
self.assertEqual(len(batch_samples), num_batches)
|
||||||
|
total_batch_samples.append(num_batches)
|
||||||
|
return batch_samples, num_items_in_batch
|
||||||
|
|
||||||
|
return wrapped_fn
|
||||||
|
|
||||||
|
trainer.get_batch_samples = wrap_get_batch_samples(trainer.get_batch_samples)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
self.assertEqual(len(trainer.get_train_dataloader()) * num_train_epochs, sum(total_batch_samples))
|
||||||
|
|
||||||
def test_early_stopping_callback(self):
|
def test_early_stopping_callback(self):
|
||||||
# early stopping stops training before num_training_epochs
|
# early stopping stops training before num_training_epochs
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
|||||||
Reference in New Issue
Block a user