[Flax] Add remat (gradient checkpointing) (#17843)

* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* make fix-copies

* fix big-bird, electra, roberta

* cookie-cutter

* fix flax big-bird

* move test to common
This commit is contained in:
Sanchit Gandhi
2022-07-01 18:33:54 +01:00
committed by GitHub
parent 664688b94f
commit 485bbe79d5
7 changed files with 414 additions and 96 deletions

View File

@@ -1099,6 +1099,33 @@ class FlaxModelTesterMixin:
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
def test_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
remat_model = model_class(config)
try:
remat_model.enable_gradient_checkpointing()
except NotImplementedError:
continue
outputs = model(**prepared_inputs_dict)
remat_outputs = remat_model(**prepared_inputs_dict)
# ensure that the dicts of outputs contain the same keys
self.assertEqual(outputs.keys(), remat_outputs.keys())
outputs = outputs.to_tuple()
remat_outputs = remat_outputs.to_tuple()
# ensure that the outputs remain precisely equal
for output, remat_output in zip(outputs, remat_outputs):
self.assertTrue((output == remat_output).all())
@require_flax
@is_staging_test