[TFT5, Cache] Add cache to TFT5 (#3772)
* correct gpt2 test inputs
* make style
* delete modeling_gpt2 change in test file
* translate from pytorch
* correct tests
* fix conflicts
* fix conflicts
* fix conflicts
* fix conflicts
* make tensorflow t5 caching work
* make style
* clean reorder cache
* remove unnecessary spaces
* fix test
This commit is contained in:
committed by
GitHub
parent
a5b249472e
commit
38f7461df3
@@ -244,7 +244,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-6))
|
||||
|
||||
def create_and_check_t5_decoder_model_attention_mask_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
@@ -293,7 +293,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def create_t5_and_check_t5_generate_with_past_key_value_states(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
config.num_layers = 1
|
||||
model = T5ForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user