Introduce configured_state arg for accelerator_config (#29781)
* Introduce configured_state * Include note on tuning * Allow for users to have defined a state already * Include tests * Add note on hpam tune * Guard a bit better * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Finish rebase * Finish rebase * Guard carefully * Fixup test * Refactor * Fin refactor * Comment * Update wrt feedback --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -131,6 +131,10 @@ if is_torch_available():
|
||||
# for version specific tests in TrainerIntegrationTest
|
||||
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||
if is_accelerate_available():
|
||||
from accelerate import Accelerator
|
||||
from accelerate.state import AcceleratorState
|
||||
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
|
||||
@@ -3266,6 +3270,16 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
|
||||
def test_accelerator_custom_state(self):
|
||||
AcceleratorState._reset_state(reset_partial_state=True)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
_ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
|
||||
self.assertIn("Please define this beforehand", str(cm.warnings[0].message))
|
||||
_ = Accelerator()
|
||||
_ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
|
||||
AcceleratorState._reset_state(reset_partial_state=True)
|
||||
|
||||
@require_accelerate_version_min_0_28
|
||||
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
Reference in New Issue
Block a user