This commit is contained in:
Ilyas Moutawwakil
2025-06-23 10:56:51 +02:00
committed by GitHub
parent 2166b6b4ff
commit 984ff89e73
16 changed files with 618 additions and 14 deletions

View File

@@ -22,6 +22,7 @@ from transformers.testing_utils import (
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_accelerator,
run_first,
torch_device,
)
from transformers.training_args import ParallelMode
@@ -116,6 +117,7 @@ if is_torch_available():
class TestTrainerDistributed(TestCasePlus):
@run_first
@require_torch_multi_accelerator
def test_trainer(self):
distributed_args = f"""--nproc_per_node={backend_device_count(torch_device)}
@@ -199,8 +201,7 @@ if __name__ == "__main__":
model = RegressionModel()
training_args.per_device_train_batch_size = 1
training_args.max_steps = 1
training_args.accelerator_config = {
"dispatch_batches": False,
}
training_args.accelerator_config.dispatch_batches = False
trainer = Trainer(model, training_args, train_dataset=train_dataset)
trainer.train()