RWKV: enable generation tests (#31490)

* add rwkv tests

* has_attentions set in individual tests
This commit is contained in:
Joao Gante
2024-06-20 14:15:01 +01:00
committed by GitHub
parent d28e647f28
commit 1fd60fec75
3 changed files with 54 additions and 19 deletions

View File

@@ -464,6 +464,8 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
config.use_cache = True
config.is_decoder = True
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(

View File

@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
pipeline_model_mapping = (
{"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {}
)
# all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
fx_compatible = False
test_missing_keys = False
test_model_parallel = False
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model = RwkvModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_beam_sample_generate_dict_output(self):
    """Run the inherited beam-sample dict-output test with `has_attentions` off.

    This model has a custom attention output shape AND config flags, so the
    attention-related checks of the parent test are skipped by temporarily
    setting `self.has_attentions` to False.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_beam_sample_generate_dict_output()
    finally:
        # Restore the flag even when the parent test fails or raises SkipTest.
        self.has_attentions = old_has_attentions
def test_beam_search_generate_dict_output(self):
    """Run the inherited beam-search dict-output test with `has_attentions` off.

    This model has a custom attention output shape AND config flags, so the
    attention-related checks of the parent test are skipped by temporarily
    setting `self.has_attentions` to False.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_beam_search_generate_dict_output()
    finally:
        # Restore the flag even when the parent test fails or raises SkipTest.
        self.has_attentions = old_has_attentions
def test_constrained_beam_search_generate_dict_output(self):
    """Run the inherited constrained-beam-search dict-output test with `has_attentions` off.

    This model has a custom attention output shape AND config flags, so the
    attention-related checks of the parent test are skipped by temporarily
    setting `self.has_attentions` to False.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_constrained_beam_search_generate_dict_output()
    finally:
        # Restore the flag even when the parent test fails or raises SkipTest.
        self.has_attentions = old_has_attentions
def test_greedy_generate_dict_outputs(self):
    """Run the inherited greedy dict-outputs test with `has_attentions` off.

    This model has a custom attention output shape AND config flags, so the
    attention-related checks of the parent test are skipped by temporarily
    setting `self.has_attentions` to False.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_greedy_generate_dict_outputs()
    finally:
        # Restore the flag even when the parent test fails or raises SkipTest.
        self.has_attentions = old_has_attentions
def test_group_beam_search_generate_dict_output(self):
    """Run the inherited group-beam-search dict-output test with `has_attentions` off.

    This model has a custom attention output shape AND config flags, so the
    attention-related checks of the parent test are skipped by temporarily
    setting `self.has_attentions` to False.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_group_beam_search_generate_dict_output()
    finally:
        # Restore the flag even when the parent test fails or raises SkipTest.
        self.has_attentions = old_has_attentions
def test_sample_generate_dict_output(self):
    """Run the inherited sample dict-output test with `has_attentions` off.

    This model has a custom attention output shape AND config flags, so the
    attention-related checks of the parent test are skipped by temporarily
    setting `self.has_attentions` to False.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_sample_generate_dict_output()
    finally:
        # Restore the flag even when the parent test fails or raises SkipTest.
        self.has_attentions = old_has_attentions
@unittest.skip("This model doesn't support padding")
def test_left_padding_compatibility(self):
    # Unconditionally skipped via the decorator above; the override exists
    # only to carry the skip reason, so the body is intentionally empty.
    pass
@unittest.skipIf(
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"