[TFT5, Cache] Add cache to TFT5 (#3772)
* correct gpt2 test inputs * make style * delete modeling_gpt2 change in test file * translate from pytorch * correct tests * fix conflicts * fix conflicts * fix conflicts * fix conflicts * make tensorflow t5 caching work * make style * clean reorder cache * remove unnecessary spaces * fix test
This commit is contained in:
committed by
GitHub
parent
a5b249472e
commit
38f7461df3
@@ -244,7 +244,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-6))
|
||||
|
||||
def create_and_check_t5_decoder_model_attention_mask_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
@@ -293,7 +293,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def create_t5_and_check_t5_generate_with_past_key_value_states(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
config.num_layers = 1
|
||||
model = T5ForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -191,7 +191,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
|
||||
|
||||
# test that outputs are equal for slice
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
|
||||
|
||||
def create_and_check_gpt2_model_attention_mask_past(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
|
||||
@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers import TFT5Model, TFT5ForConditionalGeneration, T5Tokenizer
|
||||
|
||||
|
||||
@@ -111,14 +112,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"decoder_input_ids": input_ids,
|
||||
"decoder_attention_mask": input_mask,
|
||||
}
|
||||
encoder_output, decoder_output = model(inputs)
|
||||
decoder_output, decoder_past, encoder_output = model(inputs)
|
||||
|
||||
encoder_output, decoder_output = model(
|
||||
decoder_output, decoder_past, encoder_output = model(
|
||||
input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids
|
||||
)
|
||||
|
||||
result = {
|
||||
"encoder_output": encoder_output.numpy(),
|
||||
"decoder_past": decoder_past,
|
||||
"decoder_output": decoder_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
@@ -127,6 +128,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.parent.assertListEqual(
|
||||
list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
# decoder_past[0] should correspond to encoder output
|
||||
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
|
||||
# There should be `num_layers` key value embeddings stored in decoder_past[1]
|
||||
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
|
||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
|
||||
self.parent.assertEqual(len(decoder_past[1][0]), 4)
|
||||
|
||||
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
|
||||
model = TFT5ForConditionalGeneration(config=config)
|
||||
@@ -136,7 +144,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"decoder_attention_mask": input_mask,
|
||||
}
|
||||
|
||||
prediction_scores, decoder_output = model(inputs_dict)
|
||||
prediction_scores, _, _ = model(inputs_dict)
|
||||
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
@@ -145,6 +153,76 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
|
||||
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
|
||||
model = TFT5Model(config=config).get_decoder()
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
_, past_key_value_states = model(input_ids, use_cache=True)
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
|
||||
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
|
||||
|
||||
# test that outputs are equal for slice
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
|
||||
|
||||
def create_and_check_t5_decoder_model_attention_mask_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask
|
||||
):
|
||||
model = TFT5Model(config=config).get_decoder()
|
||||
|
||||
# create attention mask
|
||||
half_seq_length = self.seq_length // 2
|
||||
attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
|
||||
attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
|
||||
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
|
||||
|
||||
# first forward pass
|
||||
_, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
# change a random masked slice from input_ids
|
||||
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
|
||||
random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
|
||||
vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
|
||||
condition = tf.transpose(
|
||||
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
|
||||
)
|
||||
input_ids = tf.where(condition, random_other_next_tokens, input_ids)
|
||||
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
attn_mask = tf.concat([attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], axis=1,)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(
|
||||
next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
|
||||
)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
|
||||
|
||||
# test that outputs are equal for slice
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask, token_labels) = config_and_inputs
|
||||
@@ -152,6 +230,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"inputs": input_ids,
|
||||
"decoder_input_ids": input_ids,
|
||||
"decoder_attention_mask": input_mask,
|
||||
"use_cache": tf.convert_to_tensor([False]),
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -170,6 +249,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
|
||||
|
||||
def test_t5_decoder_model_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
|
||||
|
||||
def test_t5_decoder_model_past_with_attn_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ["t5-small"]:
|
||||
|
||||
Reference in New Issue
Block a user