Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)
* add gradient_accumulation_kwargs to AcceleratorConfig * add suggestions from @muellerzr to docstrings, new behavior and tests * Documentation suggestions from @muellerzr Co-authored-by: Zach Mueller <muellerzr@gmail.com> * addressed @muellerzr comments regarding tests and test utils * moved accelerate version to top of file. * @muellerzr's variable fix Co-authored-by: Zach Mueller <muellerzr@gmail.com> * address @amyeroberts. fix tests and docstrings * address @amyeroberts additional suggestions --------- Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a2a7f71604
commit
4df5b9b4b2
@@ -52,6 +52,7 @@ from .integrations import (
|
|||||||
)
|
)
|
||||||
from .integrations.deepspeed import is_deepspeed_available
|
from .integrations.deepspeed import is_deepspeed_available
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
ACCELERATE_MIN_VERSION,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_aqlm_available,
|
is_aqlm_available,
|
||||||
@@ -365,11 +366,13 @@ def require_nltk(test_case):
|
|||||||
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
|
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_accelerate(test_case):
|
def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||||
"""
|
"""
|
||||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
return unittest.skipUnless(
|
||||||
|
is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
|
||||||
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_fsdp(test_case, min_version: str = "1.12.0"):
|
def require_fsdp(test_case, min_version: str = "1.12.0"):
|
||||||
|
|||||||
@@ -4324,8 +4324,23 @@ class Trainer:
|
|||||||
self.repo.git_push()
|
self.repo.git_push()
|
||||||
|
|
||||||
def create_accelerator_and_postprocess(self):
|
def create_accelerator_and_postprocess(self):
|
||||||
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
|
grad_acc_kwargs = {}
|
||||||
|
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
|
||||||
|
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
|
||||||
|
|
||||||
|
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
|
||||||
|
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
|
||||||
|
# raise because we do not know which setting is intended.
|
||||||
|
raise ValueError(
|
||||||
|
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
|
||||||
|
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
|
||||||
|
)
|
||||||
|
elif "num_steps" not in grad_acc_kwargs:
|
||||||
|
# take the gradient_accumulation_steps setting from TrainingArguments.
|
||||||
|
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
|
||||||
|
|
||||||
grad_acc_kwargs["sync_with_dataloader"] = False
|
grad_acc_kwargs["sync_with_dataloader"] = False
|
||||||
|
|
||||||
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
||||||
|
|
||||||
accelerator_config = self.args.accelerator_config.to_dict()
|
accelerator_config = self.args.accelerator_config.to_dict()
|
||||||
@@ -4337,6 +4352,8 @@ class Trainer:
|
|||||||
even_batches=accelerator_config.pop("even_batches"),
|
even_batches=accelerator_config.pop("even_batches"),
|
||||||
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
|
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
|
||||||
)
|
)
|
||||||
|
# this would have been updated above, no need for it anymore
|
||||||
|
accelerator_config.pop("gradient_accumulation_kwargs")
|
||||||
args = {
|
args = {
|
||||||
"deepspeed_plugin": self.args.deepspeed_plugin,
|
"deepspeed_plugin": self.args.deepspeed_plugin,
|
||||||
"gradient_accumulation_plugin": gradient_accumulation_plugin,
|
"gradient_accumulation_plugin": gradient_accumulation_plugin,
|
||||||
|
|||||||
@@ -1185,6 +1185,15 @@ class AcceleratorConfig:
|
|||||||
training results are fully reproducible using a different sampling technique. While seed-to-seed results
|
training results are fully reproducible using a different sampling technique. While seed-to-seed results
|
||||||
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
|
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
|
||||||
also be run with [`~utils.set_seed`] for the best results.
|
also be run with [`~utils.set_seed`] for the best results.
|
||||||
|
gradient_accumulation_kwargs (`dict`, *optional*):
|
||||||
|
Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
|
||||||
|
Any of the following (optional) keys are acceptable:
|
||||||
|
num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
|
||||||
|
the latter is set to 1, otherwise an exception will be raised.
|
||||||
|
adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
|
||||||
|
The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
|
||||||
|
sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
|
||||||
|
The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1223,6 +1232,19 @@ class AcceleratorConfig:
|
|||||||
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
|
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
gradient_accumulation_kwargs: Optional[Dict] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. "
|
||||||
|
"Any of the following (optional) keys are acceptable: "
|
||||||
|
" num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if "
|
||||||
|
" the latter is set to 1, otherwise an exception will be raised. "
|
||||||
|
" adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. "
|
||||||
|
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. "
|
||||||
|
" sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. "
|
||||||
|
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_file(cls, json_file):
|
def from_json_file(cls, json_file):
|
||||||
|
|||||||
@@ -805,9 +805,7 @@ def is_protobuf_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
|
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
|
||||||
if min_version is not None:
|
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
|
||||||
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
|
|
||||||
return _accelerate_available
|
|
||||||
|
|
||||||
|
|
||||||
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
|
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from functools import partial
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@@ -92,6 +93,7 @@ from transformers.utils import (
|
|||||||
SAFE_WEIGHTS_NAME,
|
SAFE_WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
is_accelerate_available,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
@@ -127,6 +129,9 @@ if is_torch_available():
|
|||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
|
# for version specific tests in TrainerIntegrationTest
|
||||||
|
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||||
|
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||||
|
|
||||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||||
|
|
||||||
@@ -2877,6 +2882,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.accelerator.even_batches, True)
|
self.assertEqual(trainer.accelerator.even_batches, True)
|
||||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||||
|
|
||||||
|
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||||
|
# gradient accumulation kwargs configures gradient_state
|
||||||
|
self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)
|
||||||
|
|
||||||
def test_accelerator_config_from_dict(self):
|
def test_accelerator_config_from_dict(self):
|
||||||
# Checks that accelerator kwargs can be passed through
|
# Checks that accelerator kwargs can be passed through
|
||||||
# and the accelerator is initialized respectively
|
# and the accelerator is initialized respectively
|
||||||
@@ -2885,15 +2894,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
eval_dataset = SampleIterableDataset()
|
eval_dataset = SampleIterableDataset()
|
||||||
|
|
||||||
|
accelerator_config = {
|
||||||
|
"split_batches": True,
|
||||||
|
"dispatch_batches": True,
|
||||||
|
"even_batches": False,
|
||||||
|
"use_seedable_sampler": True,
|
||||||
|
}
|
||||||
|
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||||
|
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
||||||
|
|
||||||
# Leaves all options as something *not* basic
|
# Leaves all options as something *not* basic
|
||||||
args = RegressionTrainingArguments(
|
args = RegressionTrainingArguments(
|
||||||
output_dir=tmp_dir,
|
output_dir=tmp_dir,
|
||||||
accelerator_config={
|
accelerator_config=accelerator_config,
|
||||||
"split_batches": True,
|
|
||||||
"dispatch_batches": True,
|
|
||||||
"even_batches": False,
|
|
||||||
"use_seedable_sampler": True,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
@@ -2901,6 +2914,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||||
|
|
||||||
|
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||||
|
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||||
|
|
||||||
def test_accelerator_config_from_yaml(self):
|
def test_accelerator_config_from_yaml(self):
|
||||||
# Checks that accelerator kwargs can be passed through
|
# Checks that accelerator kwargs can be passed through
|
||||||
# and the accelerator is initialized respectively
|
# and the accelerator is initialized respectively
|
||||||
@@ -2913,6 +2929,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
"even_batches": False,
|
"even_batches": False,
|
||||||
"use_seedable_sampler": False,
|
"use_seedable_sampler": False,
|
||||||
}
|
}
|
||||||
|
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||||
|
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
||||||
json.dump(accelerator_config, f)
|
json.dump(accelerator_config, f)
|
||||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
@@ -2926,11 +2944,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||||
|
|
||||||
|
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||||
|
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||||
|
|
||||||
def test_accelerator_config_from_dataclass(self):
|
def test_accelerator_config_from_dataclass(self):
|
||||||
# Checks that accelerator kwargs can be passed through
|
# Checks that accelerator kwargs can be passed through
|
||||||
# and the accelerator is initialized respectively
|
# and the accelerator is initialized respectively
|
||||||
|
|
||||||
accelerator_config = AcceleratorConfig(
|
accelerator_config = AcceleratorConfig(
|
||||||
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
|
split_batches=True,
|
||||||
|
dispatch_batches=True,
|
||||||
|
even_batches=False,
|
||||||
|
use_seedable_sampler=False,
|
||||||
)
|
)
|
||||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
@@ -2943,6 +2968,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||||
|
|
||||||
|
@require_accelerate_version_min_0_28
|
||||||
|
def test_accelerate_config_from_dataclass_grad_accum(self):
|
||||||
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
# and the accelerator is initialized respectively
|
||||||
|
|
||||||
|
grad_acc_kwargs = {
|
||||||
|
"num_steps": 10,
|
||||||
|
"adjust_scheduler": False,
|
||||||
|
"sync_with_dataloader": False,
|
||||||
|
"sync_each_batch": True,
|
||||||
|
}
|
||||||
|
accelerator_config = AcceleratorConfig(
|
||||||
|
split_batches=True,
|
||||||
|
dispatch_batches=True,
|
||||||
|
even_batches=False,
|
||||||
|
use_seedable_sampler=False,
|
||||||
|
gradient_accumulation_kwargs=grad_acc_kwargs,
|
||||||
|
)
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
|
||||||
|
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
|
||||||
|
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
|
||||||
|
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||||
|
|
||||||
def test_accelerator_config_from_partial(self):
|
def test_accelerator_config_from_partial(self):
|
||||||
# Checks that accelerator kwargs can be passed through
|
# Checks that accelerator kwargs can be passed through
|
||||||
# and the accelerator is initialized respectively
|
# and the accelerator is initialized respectively
|
||||||
@@ -3014,6 +3068,44 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
|
||||||
|
@require_accelerate_version_min_0_28
|
||||||
|
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
|
||||||
|
# case - TrainingArguments.gradient_accumulation_steps == 1
|
||||||
|
# - gradient_accumulation_kwargs['num_steps'] == 1
|
||||||
|
# results in grad accum set to 1
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
accelerator_config={
|
||||||
|
"gradient_accumulation_kwargs": {
|
||||||
|
"num_steps": 1,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)
|
||||||
|
|
||||||
|
# case - TrainingArguments.gradient_accumulation_steps > 1
|
||||||
|
# - gradient_accumulation_kwargs['num_steps'] specified
|
||||||
|
# results in exception raised
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
gradient_accumulation_steps=2,
|
||||||
|
accelerator_config={
|
||||||
|
"gradient_accumulation_kwargs": {
|
||||||
|
"num_steps": 10,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with self.assertRaises(Exception) as context:
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user