Introduce configured_state arg for accelerator_config (#29781)

* Introduce configured_state

* Include note on tuning

* Allow for users to have defined a state already

* Include tests

* Add note on hyperparameter tuning

* Guard a bit better

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Finish rebase

* Finish rebase

* Guard carefully

* Fixup test

* Refactor

* Fin refactor

* Comment

* Update wrt feedback

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2024-05-20 09:21:40 -04:00
committed by GitHub
parent bb48e92186
commit 92d1d97c05
3 changed files with 110 additions and 52 deletions

View File

@@ -131,6 +131,10 @@ if is_torch_available():
# for version specific tests in TrainerIntegrationTest
# `require_accelerate` pre-bound with min_version="0.28"; applied as a decorator on
# tests that should only run against that accelerate version or newer.
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
# Result of `is_accelerate_available("0.28")`, evaluated once at import time.
# NOTE(review): presumably True when accelerate >= 0.28 is installed — confirm against the helper.
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
if is_accelerate_available():
from accelerate import Accelerator
from accelerate.state import AcceleratorState
# Location of the sample text fixture, built from the directory returned by `get_tests_dir()`.
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -3266,6 +3270,16 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
def test_accelerator_custom_state(self):
    """Check that `use_configured_state` requires an already-configured state.

    After resetting the accelerator state, building the training arguments
    with ``accelerator_config={"use_configured_state": True}`` is expected to
    raise ``ValueError``. Once an ``Accelerator`` has been instantiated, the
    same call is expected to succeed.
    """
    # Start from a clean slate so state left behind by other tests does not leak in.
    AcceleratorState._reset_state(reset_partial_state=True)
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(ValueError) as cm:
                _ = RegressionTrainingArguments(
                    output_dir=tmp_dir, accelerator_config={"use_configured_state": True}
                )
            # The caught exception lives on `cm.exception` (a `warnings` attribute only
            # exists on `assertWarns` contexts), and it has to be inspected after the
            # `with` block: statements placed after the raising call inside it never run.
            # NOTE(review): confirm this substring matches the ValueError text raised by
            # the training-arguments code.
            self.assertIn("Please define this beforehand", str(cm.exception))

            # With a state configured up front, the same arguments must be accepted.
            _ = Accelerator()
            _ = RegressionTrainingArguments(
                output_dir=tmp_dir, accelerator_config={"use_configured_state": True}
            )
    finally:
        # Reset again even when an assertion above fails, so later tests are unaffected.
        AcceleratorState._reset_state(reset_partial_state=True)
@require_accelerate_version_min_0_28
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir: