Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)

* add gradient_accumulation_kwargs to AcceleratorConfig

* add suggestions from @muellerzr to docstrings, new behavior and tests

* Documentation suggestions from @muellerz

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* addressed @muellerzr comments regarding tests and test utils

* moved accelerate version to top of file.

* @muellerzr's variable fix

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* address @amyeroberts. fix tests and docstrings

* address @amyeroberts additional suggestions

---------

Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
Yu Chin Fabian Lim
2024-03-28 22:01:40 +08:00
committed by GitHub
parent a2a7f71604
commit 4df5b9b4b2
5 changed files with 145 additions and 13 deletions

View File

@@ -52,6 +52,7 @@ from .integrations import (
)
from .integrations.deepspeed import is_deepspeed_available
from .utils import (
ACCELERATE_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_aqlm_available,
@@ -365,11 +366,13 @@ def require_nltk(test_case):
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
def require_accelerate(test_case):
def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
return unittest.skipUnless(
is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
)(test_case)
def require_fsdp(test_case, min_version: str = "1.12.0"):