[TFT5, Cache] Add cache to TFT5 (#3772)

* correct gpt2 test inputs

* make style

* delete modeling_gpt2 change in test file

* translate from pytorch

* correct tests

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* make tensorflow t5 caching work

* make style

* clean reorder cache

* remove unnecessary spaces

* fix test
This commit is contained in:
Patrick von Platen
2020-04-16 16:14:52 +02:00
committed by GitHub
parent a5b249472e
commit 38f7461df3
6 changed files with 384 additions and 86 deletions

View File

@@ -244,7 +244,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-6))
def create_and_check_t5_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
@@ -293,7 +293,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
def create_t5_and_check_t5_generate_with_past_key_value_states(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
config.num_layers = 1
model = T5ForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()

View File

@@ -191,7 +191,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_gpt2_model_attention_mask_past(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args

View File

@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_tf, slow
if is_tf_available():
import tensorflow as tf
from transformers import TFT5Model, TFT5ForConditionalGeneration, T5Tokenizer
@@ -111,14 +112,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
encoder_output, decoder_output = model(inputs)
decoder_output, decoder_past, encoder_output = model(inputs)
encoder_output, decoder_output = model(
decoder_output, decoder_past, encoder_output = model(
input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids
)
result = {
"encoder_output": encoder_output.numpy(),
"decoder_past": decoder_past,
"decoder_output": decoder_output.numpy(),
}
self.parent.assertListEqual(
@@ -127,6 +128,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(
list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
model = TFT5ForConditionalGeneration(config=config)
@@ -136,7 +144,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
"decoder_attention_mask": input_mask,
}
prediction_scores, decoder_output = model(inputs_dict)
prediction_scores, _, _ = model(inputs_dict)
result = {
"prediction_scores": prediction_scores.numpy(),
@@ -145,6 +153,76 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
    """Check that decoding one token from cached states matches a full forward pass.

    The decoder of a freshly built ``TFT5Model`` is run once on the first
    example of ``input_ids`` with ``use_cache=True``. A randomly drawn next
    token is then fed through the decoder twice: once on its own together with
    the cached key/value states, and once appended to the original ids without
    any cache. One randomly chosen index of the last output axis is compared
    between the two results.

    Args:
        config: model configuration; ``config.vocab_size`` bounds the sampled token ids.
        input_ids: token ids; only the first row is used.
        decoder_input_ids: unused; accepted so the tester's prepared inputs can be unpacked into this call.
        attention_mask: unused; accepted for the same reason.

    Raises:
        Whatever ``tf.debugging.assert_near`` raises when the two slices differ
        by more than the relative tolerance of ``1e-6``.
    """
    model = TFT5Model(config=config).get_decoder()

    # Restrict the check to a single example.
    # NOTE: this overwrites the tester's `batch_size` attribute, so the new
    # value stays in effect on this tester instance after the call returns.
    input_ids = input_ids[:1, :]
    self.batch_size = 1

    # first forward pass
    _, past_key_value_states = model(input_ids, use_cache=True)

    # create hypothetical next token and extend to next_input_ids
    next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

    # append the next token to input_ids
    next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)

    output_from_no_past = model(next_input_ids)[0]
    output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]

    # select random slice; extract the scalar explicitly with `.numpy().item()`
    # instead of calling int() on a one-element tensor (implicit conversion of
    # a non-0-d array to a Python scalar is deprecated in recent NumPy).
    random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
    output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
    output_from_past_slice = output_from_past[:, 0, random_slice_idx]

    # test that outputs are equal for slice
    tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_t5_decoder_model_attention_mask_past(
    self, config, input_ids, decoder_input_ids, attention_mask
):
    """Check cached decoding against a full forward pass when part of the input is masked.

    An attention mask of ones over the first half of the sequence and zeros
    over the rest is used for the first (caching) forward pass. One position of
    ``input_ids`` is then overwritten with random tokens, a random next token
    is appended, and the decoder output for that token is computed both with
    and without the cached states. One randomly chosen index of the last
    output axis is compared between the two results.

    Args:
        config: model configuration; ``config.vocab_size`` bounds the sampled token ids.
        input_ids: token ids fed to the decoder.
        decoder_input_ids: unused; accepted so the tester's prepared inputs can be unpacked into this call.
        attention_mask: unused; the mask is built inside this method.
    """
    decoder = TFT5Model(config=config).get_decoder()

    # create attention mask: ones for the first half, zeros for the remainder
    visible_len = self.seq_length // 2
    hidden_len = self.seq_length - visible_len
    mask = tf.concat(
        [
            tf.ones((self.batch_size, visible_len), dtype=tf.int32),
            tf.zeros((self.batch_size, hidden_len), dtype=tf.int32),
        ],
        axis=1,
    )

    # first forward pass, keeping the key/value states for later
    _, cached_states = decoder(input_ids, attention_mask=mask, use_cache=True)

    # create hypothetical next token to extend the input with
    new_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

    # change a random masked slice from input_ids
    # (presumably ids_tensor draws from [0, visible_len), which would put the
    # altered position in the zero-masked tail — confirm against ids_tensor)
    offset_from_end = ids_tensor((1,), visible_len).numpy() + 1
    replacement_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
    target_position = self.seq_length - offset_from_end
    column_selector = tf.range(self.seq_length) == target_position
    # repeat the per-position selector for every row of the batch
    replace_here = tf.broadcast_to(
        tf.expand_dims(column_selector, 0), (self.batch_size, self.seq_length)
    )
    input_ids = tf.where(replace_here, replacement_tokens, input_ids)

    # append the new token to input_ids and a one to the mask
    extended_ids = tf.concat([input_ids, new_tokens], axis=-1)
    extra_mask_column = tf.ones((mask.shape[0], 1), dtype=tf.int32)
    mask = tf.concat([mask, extra_mask_column], axis=1)

    # get two different outputs
    full_pass_output = decoder(extended_ids, attention_mask=mask)[0]
    cached_pass_output = decoder(
        new_tokens, past_key_value_states=cached_states, attention_mask=mask
    )[0]

    # select random slice
    feature_idx = ids_tensor((1,), cached_pass_output.shape[-1]).numpy().item()
    full_pass_slice = full_pass_output[:, -1, feature_idx]
    cached_pass_slice = cached_pass_output[:, 0, feature_idx]

    # test that outputs are equal for slice
    tf.debugging.assert_near(cached_pass_slice, full_pass_slice, rtol=1e-6)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, token_labels) = config_and_inputs
@@ -152,6 +230,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
"inputs": input_ids,
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
"use_cache": tf.convert_to_tensor([False]),
}
return config, inputs_dict
@@ -170,6 +249,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
def test_t5_decoder_model_past(self):
    """Run the tester's cached-decoding check on freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_t5_decoder_model_past(*prepared)
def test_t5_decoder_model_past_with_attn_mask(self):
    """Run the tester's cached-decoding check that uses an attention mask."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_t5_decoder_model_attention_mask_past(*prepared)
@slow
def test_model_from_pretrained(self):
for model_name in ["t5-small"]: