[GPT2] Correct gradient checkpointing (#9308)

* correct gpt2

* fix gpt2

* fix use_cache ordering

* correct past tolerance

* fix for all cases

* style
This commit is contained in:
Patrick von Platen
2020-12-25 23:28:12 +01:00
committed by GitHub
parent 21fc676645
commit 61443cd7d9
3 changed files with 15 additions and 11 deletions

View File

@@ -233,6 +233,7 @@ class ModelTesterMixin:
return
config.gradient_checkpointing = True
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:

View File

@@ -247,7 +247,7 @@ class TFGPT2ModelTester:
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFGPT2LMHeadModel(config=config)