Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)
* add gradient_accumulation_kwargs to AcceleratorConfig * add suggestions from @muellerzr to docstrings, new behavior and tests * Documentation suggestions from @muellerz Co-authored-by: Zach Mueller <muellerzr@gmail.com> * addressed @muellerzr comments regarding tests and test utils * moved accelerate version to top of file. * @muellerzr's variable fix Co-authored-by: Zach Mueller <muellerzr@gmail.com> * address @amyeroberts. fix tests and docstrings * address @amyeroberts additional suggestions --------- Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a2a7f71604
commit
4df5b9b4b2
@@ -24,6 +24,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
@@ -92,6 +93,7 @@ from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_safetensors_available,
|
||||
@@ -127,6 +129,9 @@ if is_torch_available():
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
# for version specific tests in TrainerIntegrationTest
|
||||
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
|
||||
@@ -2877,6 +2882,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, True)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
# gradient accumulation kwargs configures gradient_state
|
||||
self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)
|
||||
|
||||
def test_accelerator_config_from_dict(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@@ -2885,15 +2894,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
accelerator_config = {
|
||||
"split_batches": True,
|
||||
"dispatch_batches": True,
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": True,
|
||||
}
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
||||
|
||||
# Leaves all options as something *not* basic
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config={
|
||||
"split_batches": True,
|
||||
"dispatch_batches": True,
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": True,
|
||||
},
|
||||
accelerator_config=accelerator_config,
|
||||
)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
@@ -2901,6 +2914,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_yaml(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@@ -2913,6 +2929,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": False,
|
||||
}
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
||||
json.dump(accelerator_config, f)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
@@ -2926,11 +2944,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_dataclass(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
|
||||
accelerator_config = AcceleratorConfig(
|
||||
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
|
||||
split_batches=True,
|
||||
dispatch_batches=True,
|
||||
even_batches=False,
|
||||
use_seedable_sampler=False,
|
||||
)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
@@ -2943,6 +2968,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||
|
||||
@require_accelerate_version_min_0_28
|
||||
def test_accelerate_config_from_dataclass_grad_accum(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
|
||||
grad_acc_kwargs = {
|
||||
"num_steps": 10,
|
||||
"adjust_scheduler": False,
|
||||
"sync_with_dataloader": False,
|
||||
"sync_each_batch": True,
|
||||
}
|
||||
accelerator_config = AcceleratorConfig(
|
||||
split_batches=True,
|
||||
dispatch_batches=True,
|
||||
even_batches=False,
|
||||
use_seedable_sampler=False,
|
||||
gradient_accumulation_kwargs=grad_acc_kwargs,
|
||||
)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_partial(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@@ -3014,6 +3068,44 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
|
||||
@require_accelerate_version_min_0_28
|
||||
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# case - TrainingArguments.gradient_accumulation_steps == 1
|
||||
# - gradient_accumulation_kwargs['num_steps] == 1
|
||||
# results in grad accum set to 1
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
gradient_accumulation_steps=1,
|
||||
accelerator_config={
|
||||
"gradient_accumulation_kwargs": {
|
||||
"num_steps": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)
|
||||
|
||||
# case - TrainingArguments.gradient_accumulation_steps > 1
|
||||
# - gradient_accumulation_kwargs['num_steps] specified
|
||||
# results in exception raised
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
gradient_accumulation_steps=2,
|
||||
accelerator_config={
|
||||
"gradient_accumulation_kwargs": {
|
||||
"num_steps": 10,
|
||||
}
|
||||
},
|
||||
)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user