From 1fd60fec75b2f75432cca48ff79d606e084a9bc2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 20 Jun 2024 14:15:01 +0100 Subject: [PATCH] RWKV: enable generation tests (#31490) * add rwkv tests * has_attentions set in individual tests --- src/transformers/models/rwkv/modeling_rwkv.py | 21 ++------ tests/generation/test_utils.py | 4 ++ tests/models/rwkv/test_modeling_rwkv.py | 48 ++++++++++++++++++- 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 8568bd999e..f6b8cd412b 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -625,6 +625,9 @@ class RwkvModel(RwkvPreTrainedModel): use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if attention_mask is not None: + logger.warning_once("`attention_mask` was passed, but it is unused in this model.") + if self.training == self.layers_are_rescaled: self._rescale_layers() @@ -765,24 +768,6 @@ class RwkvForCausalLM(RwkvPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.head = new_embeddings - def generate(self, *args, **kwargs): - # Thin wrapper to raise exceptions when trying to generate with methods that manipulate `past_key_values`. - # RWKV is one of the few models that don't have it (it has `state` instead, which has different properties and - # usage). - try: - gen_output = super().generate(*args, **kwargs) - except AttributeError as exc: - # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'" - if "past_key_values" in str(exc): - raise AttributeError( - "You tried to call `generate` with a decoding strategy that manipulates `past_key_values`. RWKV " - "doesn't have that attribute, try another generation strategy instead. 
For the available " - "generation strategies, check this doc: https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" - ) - else: - raise exc - return gen_output - def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): # only last token for inputs_ids if the state is passed along. if state is not None: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6215bc87ed..f61adbbd90 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -464,6 +464,8 @@ class GenerationTesterMixin: if not hasattr(config, "use_cache"): self.skipTest("This model doesn't support caching") + if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): + self.skipTest("Won't fix: model with non-standard dictionary output shapes") config.use_cache = True config.is_decoder = True @@ -624,6 +626,8 @@ class GenerationTesterMixin: if not hasattr(config, "use_cache"): self.skipTest("This model doesn't support caching") + if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): + self.skipTest("Won't fix: model with non-standard dictionary output shapes") model = model_class(config).to(torch_device).eval() logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index d2a41a863d..47590c98d4 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin pipeline_model_mapping = ( {"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {} ) - # all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else () + all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else () fx_compatible = False test_missing_keys = 
False test_model_parallel = False @@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin model = RwkvModel.from_pretrained(model_name) self.assertIsNotNone(model) + def test_beam_sample_generate_dict_output(self): + # This model has a custom attention output shape AND config flags, let's skip those checks + old_has_attentions = self.has_attentions + self.has_attentions = False + super().test_beam_sample_generate_dict_output() + self.has_attentions = old_has_attentions + + def test_beam_search_generate_dict_output(self): + # This model has a custom attention output shape AND config flags, let's skip those checks + old_has_attentions = self.has_attentions + self.has_attentions = False + super().test_beam_search_generate_dict_output() + self.has_attentions = old_has_attentions + + def test_constrained_beam_search_generate_dict_output(self): + # This model has a custom attention output shape AND config flags, let's skip those checks + old_has_attentions = self.has_attentions + self.has_attentions = False + super().test_constrained_beam_search_generate_dict_output() + self.has_attentions = old_has_attentions + + def test_greedy_generate_dict_outputs(self): + # This model has a custom attention output shape AND config flags, let's skip those checks + old_has_attentions = self.has_attentions + self.has_attentions = False + super().test_greedy_generate_dict_outputs() + self.has_attentions = old_has_attentions + + def test_group_beam_search_generate_dict_output(self): + # This model has a custom attention output shape AND config flags, let's skip those checks + old_has_attentions = self.has_attentions + self.has_attentions = False + super().test_group_beam_search_generate_dict_output() + self.has_attentions = old_has_attentions + + def test_sample_generate_dict_output(self): + # This model has a custom attention output shape AND config flags, let's skip those checks + old_has_attentions = self.has_attentions + 
self.has_attentions = False + super().test_sample_generate_dict_output() + self.has_attentions = old_has_attentions + + @unittest.skip("This model doesn't support padding") + def test_left_padding_compatibility(self): + pass + @unittest.skipIf( not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"