[TFT5, Cache] Add cache to TFT5 (#3772)

* correct gpt2 test inputs

* make style

* delete modeling_gpt2 change in test file

* translate from pytorch

* correct tests

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* make tensorflow t5 caching work

* make style

* clean reorder cache

* remove unnecessary spaces

* fix test
This commit is contained in:
Patrick von Platen
2020-04-16 16:14:52 +02:00
committed by GitHub
parent a5b249472e
commit 38f7461df3
6 changed files with 384 additions and 86 deletions

View File

@@ -191,7 +191,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
-tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)
+tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_gpt2_model_attention_mask_past(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args