TF generate refactor - past without encoder outputs (#15944)

* Remove packed past from generation_tf_utils

* update models with the new past format

* update template accordingly
This commit is contained in:
Joao Gante
2022-03-08 14:46:44 +00:00
committed by GitHub
parent 62d847602a
commit 70203b5937
30 changed files with 301 additions and 684 deletions

View File

@@ -98,13 +98,10 @@ class TFT5ModelTester:
encoder_output = result.encoder_last_hidden_state
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
self.parent.assertEqual(len(decoder_past), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
self.parent.assertEqual(len(decoder_past[0]), 4)
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
model = TFT5ForConditionalGeneration(config=config)