From e387821a96ce23f4ba57d5d11a6c328500af3f98 Mon Sep 17 00:00:00 2001 From: efsotr <104755879+efsotr@users.noreply.github.com> Date: Mon, 12 May 2025 23:45:24 +0800 Subject: [PATCH] Fix tot update in trainer (#37923) * fix total updates in epoch * add test; fix max_steps * replace with multi-gpu decorator --- src/transformers/trainer.py | 14 +++++++++----- tests/trainer/test_trainer.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 743cc353ab..ccbb4ebe44 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2495,13 +2495,13 @@ class Trainer: step = -1 epoch_iterator = iter(epoch_dataloader) # We chunkify the epoch iterator into gradient accumulation steps `n` batches - remainder = num_examples % args.gradient_accumulation_steps + remainder = steps_in_epoch % args.gradient_accumulation_steps if remainder == 0: remainder = args.gradient_accumulation_steps update_step = -1 - total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 - if args.gradient_accumulation_steps == 1: - total_updates -= 1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( + remainder < args.gradient_accumulation_steps + ) for _ in range(total_updates): update_step += 1 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder @@ -5319,7 +5319,11 @@ class Trainer: # Case 2: We have a dataloader length and can extrapolate if len_dataloader is not None: - num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1) + num_update_steps_per_epoch = max( + len_dataloader // args.gradient_accumulation_steps + + int(len_dataloader % args.gradient_accumulation_steps > 0), + 1, + ) # Case 3: We have a length but are using epochs, we can extrapolate the number of steps if epoch_based: max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 278312329f..f34374a44c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -97,6 +97,7 @@ from transformers.testing_utils import ( require_torch_fp16, require_torch_gpu, require_torch_multi_accelerator, + require_torch_multi_gpu, require_torch_non_multi_accelerator, require_torch_non_multi_gpu, require_torch_tensorrt_fx, @@ -3763,6 +3764,37 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): train_output = trainer.train() self.assertEqual(train_output.global_step, int(self.n_epochs)) + @require_torch_multi_gpu + def test_num_batches_in_training_with_gradient_accumulation(self): + with tempfile.TemporaryDirectory() as tmp_dir: + for num_train_epochs in [1, 2]: + for train_len in [123, 120]: + trainer = get_regression_trainer( + train_len=train_len, + per_device_train_batch_size=4, + gradient_accumulation_steps=5, + num_train_epochs=num_train_epochs, + output_dir=tmp_dir, + ) + + total_batch_samples = [] + + def wrap_get_batch_samples(fn): + def wrapped_fn(epoch_iterator, num_batches, device): + self.assertGreater(num_batches, 0) + batch_samples, num_items_in_batch = fn(epoch_iterator, num_batches, device) + self.assertEqual(len(batch_samples), num_batches) + total_batch_samples.append(num_batches) + return batch_samples, num_items_in_batch + + return wrapped_fn + + trainer.get_batch_samples = wrap_get_batch_samples(trainer.get_batch_samples) + + trainer.train() + + self.assertEqual(len(trainer.get_train_dataloader()) * num_train_epochs, sum(total_batch_samples)) + def test_early_stopping_callback(self): # early stopping stops training before num_training_epochs with tempfile.TemporaryDirectory() as tmp_dir: