TF generate refactor - past without encoder outputs (#15944)
* Remove packed past from generation_tf_utils * update models with the new past format * update template accordingly
This commit is contained in:
@@ -116,7 +116,6 @@ class TFBartModelTester:
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
@@ -114,7 +114,6 @@ class TFBlenderbotModelTester:
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
@@ -114,7 +114,6 @@ class TFBlenderbotSmallModelTester:
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
@@ -133,7 +133,6 @@ class TFLEDModelTester:
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
@@ -116,7 +116,6 @@ class TFMarianModelTester:
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
@@ -114,7 +114,6 @@ class TFPegasusModelTester:
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
@@ -182,7 +182,7 @@ class TFSpeech2TextModelTester:
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
_, (_, past_key_values) = outputs.to_tuple()
|
||||
_, past_key_values = outputs.to_tuple()
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = tf.math.maximum(ids_tensor((self.batch_size, 3), config.vocab_size), 2)
|
||||
|
||||
@@ -98,13 +98,10 @@ class TFT5ModelTester:
|
||||
encoder_output = result.encoder_last_hidden_state
|
||||
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
# decoder_past[0] should correspond to encoder output
|
||||
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
|
||||
# There should be `num_layers` key value embeddings stored in decoder_past[1]
|
||||
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
|
||||
self.parent.assertEqual(len(decoder_past), config.num_layers)
|
||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
|
||||
self.parent.assertEqual(len(decoder_past[1][0]), 4)
|
||||
self.parent.assertEqual(len(decoder_past[0]), 4)
|
||||
|
||||
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
|
||||
model = TFT5ForConditionalGeneration(config=config)
|
||||
|
||||
Reference in New Issue
Block a user