Introduce AcceleratorConfig dataclass (#28664)
* Introduce acceleratorconfig dataclass * Extra second warn * Move import * Try moving import under is_accelerate_available * Quality * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Clean * Remove to_kwargs * Change version * Improve tests by including dispatch and split batches * Improve reliability * Update tests/trainer/test_trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fixup tests and review nits * Make tests pass * protect import * Protect import * Empty-Commit * Make training_args.to_dict handle the AcceleratorConfig --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -118,6 +118,7 @@ if is_torch_available():
|
||||
TrainerState,
|
||||
)
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
@@ -2412,6 +2413,146 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
execute_subprocess_async(command)
|
||||
# successful return here == success - any errors would have caused an error or a timeout in the sub-call
|
||||
|
||||
def test_accelerator_config_empty(self):
|
||||
# Checks that a config can be made with the defaults if not passed
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Leaves one option as something *not* basic
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, False)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||
self.assertEqual(trainer.accelerator.even_batches, True)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
def test_accelerator_config_from_dict(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Leaves all options as something *not* basic
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config={
|
||||
"split_batches": True,
|
||||
"dispatch_batches": True,
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": True,
|
||||
},
|
||||
)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, True)
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
def test_accelerator_config_from_yaml(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
path_file = Path(tmp_dir) / "accelerator_config.json"
|
||||
with open(path_file, "w") as f:
|
||||
accelerator_config = {
|
||||
"split_batches": True,
|
||||
"dispatch_batches": True,
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": False,
|
||||
}
|
||||
json.dump(accelerator_config, f)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Leaves all options as something *not* basic
|
||||
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=path_file)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, True)
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||
|
||||
def test_accelerator_config_from_dataclass(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
accelerator_config = AcceleratorConfig(
|
||||
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
|
||||
)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, True)
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||
|
||||
def test_accelerator_config_from_partial(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Leaves one option as something *not* basic
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config={
|
||||
"split_batches": True,
|
||||
},
|
||||
)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||
self.assertEqual(trainer.accelerator.even_batches, True)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
def test_accelerator_config_from_dict_with_deprecated_args(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
# and maintains the deprecated args if passed in
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Leaves all options as something *not* basic
|
||||
with self.assertWarns(FutureWarning) as cm:
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config={
|
||||
"split_batches": True,
|
||||
},
|
||||
dispatch_batches=False,
|
||||
)
|
||||
self.assertIn("dispatch_batches", str(cm.warnings[0].message))
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, False)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
with self.assertWarns(FutureWarning) as cm:
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config={
|
||||
"even_batches": False,
|
||||
},
|
||||
split_batches=True,
|
||||
)
|
||||
self.assertIn("split_batches", str(cm.warnings[0].message))
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user