committed by
GitHub
parent
5792c459ed
commit
1759bb9126
@@ -16,9 +16,10 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
|
||||
from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
@@ -59,7 +60,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = ()
|
||||
all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Gemma2Model,
|
||||
@@ -89,6 +90,101 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_eager_matches_sdpa_inference(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all")
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
src_len = min_length + idx if not use_cache else max_length
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
||||
self.assertIsInstance(past_key_values, HybridCache)
|
||||
|
||||
# check shape key, value (batch, head, max_seq_length, head_features)
|
||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
num_key_value_heads = (
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
||||
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
|
||||
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
|
||||
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
|
||||
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
|
||||
|
||||
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
|
||||
def test_sdpa_equivalence(self):
|
||||
pass
|
||||
@@ -203,6 +299,5 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
|
||||
print(output_text)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
Reference in New Issue
Block a user