[TFT5, Cache] Add cache to TFT5 (#3772)

* correct gpt2 test inputs

* make style

* delete modeling_gpt2 change in test file

* translate from pytorch

* correct tests

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* make tensorflow t5 caching work

* make style

* clean reorder cache

* remove unnecessary spaces

* fix test
This commit is contained in:
Patrick von Platen
2020-04-16 16:14:52 +02:00
committed by GitHub
parent a5b249472e
commit 38f7461df3
6 changed files with 384 additions and 86 deletions

View File

@@ -244,7 +244,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-6))
def create_and_check_t5_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
@@ -293,7 +293,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
def create_t5_and_check_t5_generate_with_past_key_value_states(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
config.num_layers = 1
model = T5ForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()