TF generate refactor - past without encoder outputs (#15944)
* Remove packed past from generation_tf_utils * update models with the new past format * update template accordingly
This commit is contained in:
@@ -98,13 +98,10 @@ class TFT5ModelTester:
|
||||
encoder_output = result.encoder_last_hidden_state
|
||||
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
# decoder_past[0] should correspond to encoder output
|
||||
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
|
||||
# There should be `num_layers` key value embeddings stored in decoder_past[1]
|
||||
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
|
||||
self.parent.assertEqual(len(decoder_past), config.num_layers)
|
||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
|
||||
self.parent.assertEqual(len(decoder_past[1][0]), 4)
|
||||
self.parent.assertEqual(len(decoder_past[0]), 4)
|
||||
|
||||
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
|
||||
model = TFT5ForConditionalGeneration(config=config)
|
||||
|
||||
Reference in New Issue
Block a user