Fix tot update in trainer (#37923)

* fix total updates in epoch

* add test; fix max_steps

* replace with multi-gpu decorator
This commit is contained in:
efsotr
2025-05-12 23:45:24 +08:00
committed by GitHub
parent f0e975c6cf
commit e387821a96
2 changed files with 41 additions and 5 deletions

View File

@@ -97,6 +97,7 @@ from transformers.testing_utils import (
require_torch_fp16,
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_non_multi_accelerator,
require_torch_non_multi_gpu,
require_torch_tensorrt_fx,
@@ -3763,6 +3764,37 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(self.n_epochs))
@require_torch_multi_gpu
def test_num_batches_in_training_with_gradient_accumulation(self):
with tempfile.TemporaryDirectory() as tmp_dir:
for num_train_epochs in [1, 2]:
for train_len in [123, 120]:
trainer = get_regression_trainer(
train_len=train_len,
per_device_train_batch_size=4,
gradient_accumulation_steps=5,
num_train_epochs=num_train_epochs,
output_dir=tmp_dir,
)
total_batch_samples = []
def wrap_get_batch_samples(fn):
def wrapped_fn(epoch_iterator, num_batches, device):
self.assertGreater(num_batches, 0)
batch_samples, num_items_in_batch = fn(epoch_iterator, num_batches, device)
self.assertEqual(len(batch_samples), num_batches)
total_batch_samples.append(num_batches)
return batch_samples, num_items_in_batch
return wrapped_fn
trainer.get_batch_samples = wrap_get_batch_samples(trainer.get_batch_samples)
trainer.train()
self.assertEqual(len(trainer.get_train_dataloader()) * num_train_epochs, sum(total_batch_samples))
def test_early_stopping_callback(self):
# early stopping stops training before num_training_epochs
with tempfile.TemporaryDirectory() as tmp_dir: