[Trainer] implement gradient_accumulation_steps support in DeepSpeed integration (#10310)

* implement gradient_accumulation_steps support in DeepSpeed integration

* typo

* cleanup

* cleanup
This commit is contained in:
Stas Bekman
2021-02-22 11:15:59 -08:00
committed by GitHub
parent f991daed18
commit eab0afc19c
5 changed files with 162 additions and 27 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import logging
import os
@@ -830,14 +831,49 @@ class TestCasePlus(unittest.TestCase):
def mockenv(**kwargs):
"""
this is a convenience wrapper, that allows this:
this is a convenience wrapper, that allows this ::
@mockenv(RUN_SLOW=True, USE_TF=False)
def test_something():
run_slow = os.getenv("RUN_SLOW", False)
use_tf = os.getenv("USE_TF", False)
@mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): run_slow = os.getenv("RUN_SLOW", False) use_tf =
os.getenv("USE_TF", False)
"""
return unittest.mock.patch.dict(os.environ, kwargs)
# from https://stackoverflow.com/a/34333710/9201239
@contextlib.contextmanager
def mockenv_context(*remove, **update):
"""
Temporarily updates the ``os.environ`` dictionary in-place. Similar to mockenv
The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.
Args:
remove: Environment variables to remove.
update: Dictionary of environment variables and values to add/update.
"""
env = os.environ
update = update or {}
remove = remove or []
# List of environment variables being updated or removed.
stomped = (set(update.keys()) | set(remove)) & set(env.keys())
# Environment variables and values to restore on exit.
update_after = {k: env[k] for k in stomped}
# Environment variables and values to remove on exit.
remove_after = frozenset(k for k in update if k not in env)
try:
env.update(update)
[env.pop(k, None) for k in remove]
yield
finally:
env.update(update_after)
[env.pop(k) for k in remove_after]
# --- pytest conf functions --- #
# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once

View File

@@ -718,7 +718,7 @@ class Trainer:
def _wrap_model(self, model, training=True):
# already initialized its own DDP and AMP
if self.deepspeed:
return model
return self.deepspeed
# Mixed precision training with apex (torch < 1.6)
if self.use_apex and training:
@@ -996,6 +996,10 @@ class Trainer:
tr_loss += self.training_step(model, inputs)
self._total_flos += float(self.floating_point_ops(inputs))
# Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
if self.deepspeed:
self.deepspeed.step()
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= self.args.gradient_accumulation_steps
@@ -1021,7 +1025,7 @@ class Trainer:
# Optimizer step
if self.deepspeed:
self.deepspeed.step()
pass # called outside the loop
elif is_torch_tpu_available():
xm.optimizer_step(self.optimizer)
elif self.use_amp:
@@ -1030,7 +1034,9 @@ class Trainer:
else:
self.optimizer.step()
self.lr_scheduler.step()
if not self.deepspeed:
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
@@ -1388,7 +1394,6 @@ class Trainer:
Return:
:obj:`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
@@ -1401,7 +1406,8 @@ class Trainer:
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.args.gradient_accumulation_steps > 1:
if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
# deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
loss = loss / self.args.gradient_accumulation_steps
if self.use_amp:
@@ -1410,7 +1416,8 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
self.deepspeed.backward(loss)
# loss gets scaled under gradient_accumulation_steps in deepspeed
loss = self.deepspeed.backward(loss)
else:
loss.backward()