[GPT2] Correct gradient checkpointing (#9308)
* correct gpt2
* fix gpt2
* fix use_cache ordering
* correct past tolerance
* fix for all cases
* style
committed by GitHub
parent 21fc676645
commit 61443cd7d9
@@ -233,6 +233,7 @@ class ModelTesterMixin:
            return

        config.gradient_checkpointing = True
        config.use_cache = False
        config.return_dict = True

        for model_class in self.all_model_classes:
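
The setup in this hunk can be sketched without the library. This is a minimal illustration, not the test's actual code: `DummyConfig` and `enable_checkpointing_for_test` are hypothetical names, and the comment explaining why `use_cache` is switched off is an assumption based on the "fix use_cache ordering" item in the commit message.

```python
from dataclasses import dataclass


@dataclass
class DummyConfig:
    """Hypothetical stand-in for a model config; not the transformers class."""

    gradient_checkpointing: bool = False
    use_cache: bool = True
    return_dict: bool = False


def enable_checkpointing_for_test(config: DummyConfig) -> DummyConfig:
    """Set the three flags in the same order as the hunk above."""
    # Trade compute for memory: activations are recomputed in the backward pass.
    config.gradient_checkpointing = True
    # Assumption: cached past key/value states serve generation, not a
    # checkpointed training pass, so the cache is turned off.
    config.use_cache = False
    # Ask for structured outputs instead of plain tuples.
    config.return_dict = True
    return config


config = enable_checkpointing_for_test(DummyConfig())
print(config)
```

In the real test, the configured object is then used to build each class in `self.all_model_classes`; the sketch stops before that step because the rest of the loop body is not shown in the hunk.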