[Trainer] implement gradient_accumulation_steps support in DeepSpeed integration (#10310)
* implement gradient_accumulation_steps support in DeepSpeed integration * typo * cleanup * cleanup
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@@ -830,14 +831,49 @@ class TestCasePlus(unittest.TestCase):
|
||||
|
||||
def mockenv(**kwargs):
|
||||
"""
|
||||
this is a convenience wrapper, that allows this:
|
||||
this is a convenience wrapper, that allows this ::
|
||||
|
||||
@mockenv(RUN_SLOW=True, USE_TF=False)
|
||||
def test_something():
|
||||
run_slow = os.getenv("RUN_SLOW", False)
|
||||
use_tf = os.getenv("USE_TF", False)
|
||||
|
||||
@mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): run_slow = os.getenv("RUN_SLOW", False) use_tf =
|
||||
os.getenv("USE_TF", False)
|
||||
"""
|
||||
return unittest.mock.patch.dict(os.environ, kwargs)
|
||||
|
||||
|
||||
# from https://stackoverflow.com/a/34333710/9201239
|
||||
@contextlib.contextmanager
|
||||
def mockenv_context(*remove, **update):
|
||||
"""
|
||||
Temporarily updates the ``os.environ`` dictionary in-place. Similar to mockenv
|
||||
|
||||
The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.
|
||||
|
||||
Args:
|
||||
remove: Environment variables to remove.
|
||||
update: Dictionary of environment variables and values to add/update.
|
||||
"""
|
||||
env = os.environ
|
||||
update = update or {}
|
||||
remove = remove or []
|
||||
|
||||
# List of environment variables being updated or removed.
|
||||
stomped = (set(update.keys()) | set(remove)) & set(env.keys())
|
||||
# Environment variables and values to restore on exit.
|
||||
update_after = {k: env[k] for k in stomped}
|
||||
# Environment variables and values to remove on exit.
|
||||
remove_after = frozenset(k for k in update if k not in env)
|
||||
|
||||
try:
|
||||
env.update(update)
|
||||
[env.pop(k, None) for k in remove]
|
||||
yield
|
||||
finally:
|
||||
env.update(update_after)
|
||||
[env.pop(k) for k in remove_after]
|
||||
|
||||
|
||||
# --- pytest conf functions --- #
|
||||
|
||||
# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
|
||||
|
||||
Reference in New Issue
Block a user