Introduce AcceleratorConfig dataclass (#28664)

* Introduce acceleratorconfig dataclass

* Extra second warn

* Move import

* Try moving import under is_accelerate_available

* Quality

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Clean

* Remove to_kwargs

* Change version

* Improve tests by including dispatch and split batches

* Improve reliability

* Update tests/trainer/test_trainer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixup tests and review nits

* Make tests pass

* protect import

* Protect import

* Empty-Commit

* Make training_args.to_dict handle the AcceleratorConfig

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2024-02-14 10:18:09 -05:00
committed by GitHub
parent 69ca640dd6
commit 0507e69d34
4 changed files with 307 additions and 14 deletions

View File

@@ -118,6 +118,7 @@ if is_torch_available():
TrainerState,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_pt_utils import AcceleratorConfig
if is_safetensors_available():
import safetensors.torch
@@ -2412,6 +2413,146 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
execute_subprocess_async(command)
# successful return here == success - any errors would have caused an error or a timeout in the sub-call
def test_accelerator_config_empty(self):
# Checks that a config can be made with the defaults if not passed
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Leaves one option as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, False)
self.assertEqual(trainer.accelerator.dispatch_batches, None)
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
def test_accelerator_config_from_dict(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Leaves all options as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": True,
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
with tempfile.TemporaryDirectory() as tmp_dir:
path_file = Path(tmp_dir) / "accelerator_config.json"
with open(path_file, "w") as f:
accelerator_config = {
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": False,
}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Leaves all options as something *not* basic
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=path_file)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
accelerator_config = AcceleratorConfig(
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Leaves one option as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
"split_batches": True,
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, None)
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
def test_accelerator_config_from_dict_with_deprecated_args(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
# and maintains the deprecated args if passed in
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Leaves all options as something *not* basic
with self.assertWarns(FutureWarning) as cm:
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
"split_batches": True,
},
dispatch_batches=False,
)
self.assertIn("dispatch_batches", str(cm.warnings[0].message))
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.dispatch_batches, False)
self.assertEqual(trainer.accelerator.split_batches, True)
with self.assertWarns(FutureWarning) as cm:
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
"even_batches": False,
},
split_batches=True,
)
self.assertIn("split_batches", str(cm.warnings[0].message))
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.dispatch_batches, None)
@require_torch
@is_staging_test