[GPT2] Correct gradient checkpointing (#9308)

* correct gpt2

* fix gpt2

* fix use_cache ordering

* correct past tolerance

* fix for all cases

* style
Patrick von Platen
2020-12-25 23:28:12 +01:00
committed by GitHub
parent 21fc676645
commit 61443cd7d9
3 changed files with 15 additions and 11 deletions

@@ -233,6 +233,7 @@ class ModelTesterMixin:
            return
        config.gradient_checkpointing = True
        config.use_cache = False
        config.return_dict = True
        for model_class in self.all_model_classes: