Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)
* add gradient_accumulation_kwargs to AcceleratorConfig * add suggestions from @muellerzr to docstrings, new behavior and tests * Documentation suggestions from @muellerz Co-authored-by: Zach Mueller <muellerzr@gmail.com> * addressed @muellerzr comments regarding tests and test utils * moved accelerate version to top of file. * @muellerzr's variable fix Co-authored-by: Zach Mueller <muellerzr@gmail.com> * address @amyeroberts. fix tests and docstrings * address @amyeroberts additional suggestions --------- Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a2a7f71604
commit
4df5b9b4b2
@@ -52,6 +52,7 @@ from .integrations import (
|
||||
)
|
||||
from .integrations.deepspeed import is_deepspeed_available
|
||||
from .utils import (
|
||||
ACCELERATE_MIN_VERSION,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_aqlm_available,
|
||||
@@ -365,11 +366,13 @@ def require_nltk(test_case):
|
||||
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
return unittest.skipUnless(
|
||||
is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_fsdp(test_case, min_version: str = "1.12.0"):
|
||||
|
||||
Reference in New Issue
Block a user