[GPT2] Correct gradient checkpointing (#9308)
* correct gpt2 * fix gpt2 * fix use_cache ordering * correct past tolerance * fix for all cases * style
This commit is contained in:
committed by
GitHub
parent
21fc676645
commit
61443cd7d9
@@ -247,7 +247,7 @@ class TFGPT2ModelTester:
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx]
|
||||
|
||||
# test that outputs are equal for slice
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
|
||||
|
||||
def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFGPT2LMHeadModel(config=config)
|
||||
|
||||
Reference in New Issue
Block a user