From 918a06e25dfd6f79a20b6f07f63598c71e440161 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 16 May 2023 19:28:19 +0100 Subject: [PATCH] Generate: add test to check KV format (#23403) Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- tests/generation/test_utils.py | 71 +++++++++++++++++++ tests/models/bloom/test_modeling_bloom.py | 4 ++ .../gpt_bigcode/test_modeling_gpt_bigcode.py | 4 ++ .../models/reformer/test_modeling_reformer.py | 6 +- 4 files changed, 84 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 70de057d5f..06e56498d9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1645,6 +1645,77 @@ class GenerationTesterMixin: self.assertTrue(no_failures) + def test_past_key_values_format(self): + # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a + # standard KV cache format is important for a consistent API (and for advanced generation methods). + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + # If it doesn't support cache, pass the test + if not hasattr(config, "use_cache"): + return + + model = model_class(config).to(torch_device) + if "use_cache" not in inputs: + inputs["use_cache"] = True + outputs = model(**inputs) + + # If "past_key_values" is not returned, pass the test (e.g. 
RWKV uses a different cache name and format) + if "past_key_values" not in outputs: + return + + num_hidden_layers = ( + getattr(config, "decoder_layers", None) + or getattr(config, "num_decoder_layers", None) + or config.num_hidden_layers + ) + num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads) + embed_dim = getattr(config, "d_model", config.hidden_size) + per_head_embed_dim = embed_dim // num_attention_heads + + past_kv = outputs["past_key_values"] + self.assertEqual(len(past_kv), num_hidden_layers) + + # Encoder-Decoder checks + if config.is_encoder_decoder: + encoder_num_attention_heads = config.encoder_attention_heads + encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads + batch_size, seq_length = inputs["decoder_input_ids"].shape + for i in range(num_hidden_layers): + self.assertEqual(len(past_kv[i]), 4) # K V for the decoder + K V for the encoder = 4 + self.assertEqual( + past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + ) + self.assertEqual( + past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + ) + # The sequence length for the encoder K V depends on the model. Since it is not manipulated in + # autoregressive generation, I'm keeping the test general and not checking the 3rd dim + self.assertEqual( + (past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]), + (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim), + ) + self.assertEqual( + (past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]), + (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim), + ) + + # Decoder-only checks + else: + # TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". 
Fix the + # tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the latter (and the other + # tests use it) + key = "input_ids" if "input_ids" in inputs else "pixel_values" + batch_size, seq_length = inputs[key].shape + for i in range(num_hidden_layers): + self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2 + self.assertEqual( + past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + ) + self.assertEqual( + past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + ) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 617998cc61..678c46bd0c 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -393,6 +393,10 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs) + @unittest.skip("Bloom has a non-standard KV cache format.") + def test_past_key_values_format(self): + pass + @slow def test_model_from_pretrained(self): for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index a9f4d204bf..01e6ceef9e 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -475,6 +475,10 @@ class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste def test_disk_offload(self): pass + @unittest.skip("BigCodeGPT has a non-standard KV cache format.") + def test_past_key_values_format(self): + pass + def 
test_gpt_bigcode_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs) diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py index a7f4f2f454..39e1389477 100644 --- a/tests/models/reformer/test_modeling_reformer.py +++ b/tests/models/reformer/test_modeling_reformer.py @@ -831,8 +831,12 @@ class ReformerLSHAttnModelTest( [expected_shape] * len(iter_hidden_states), ) + @unittest.skip("Fails because the sequence length is not a multiple of 4") def test_problem_types(self): - # Fails because the sequence length is not a multiple of 4 + pass + + @unittest.skip("Fails because the sequence length is not a multiple of 4") + def test_past_key_values_format(self): pass