RWKV: enable generation tests (#31490)
* add rwkv tests * has_attentions set in individual tests
This commit is contained in:
@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {}
|
||||
)
|
||||
# all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_missing_keys = False
|
||||
test_model_parallel = False
|
||||
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
model = RwkvModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
self.has_attentions = False
|
||||
super().test_beam_sample_generate_dict_output()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
self.has_attentions = False
|
||||
super().test_beam_search_generate_dict_output()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
self.has_attentions = False
|
||||
super().test_constrained_beam_search_generate_dict_output()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
self.has_attentions = False
|
||||
super().test_greedy_generate_dict_outputs()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
self.has_attentions = False
|
||||
super().test_group_beam_search_generate_dict_output()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
def test_sample_generate_dict_output(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
self.has_attentions = False
|
||||
super().test_sample_generate_dict_output()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
@unittest.skip("This model doesn't support padding")
|
||||
def test_left_padding_compatibility(self):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
|
||||
Reference in New Issue
Block a user