RWKV: enable generation tests (#31490)
* add rwkv tests * has_attentions set in individual tests
This commit is contained in:
@@ -625,6 +625,9 @@ class RwkvModel(RwkvPreTrainedModel):
|
|||||||
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
|
||||||
|
|
||||||
if self.training == self.layers_are_rescaled:
|
if self.training == self.layers_are_rescaled:
|
||||||
self._rescale_layers()
|
self._rescale_layers()
|
||||||
|
|
||||||
@@ -765,24 +768,6 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.head = new_embeddings
|
self.head = new_embeddings
|
||||||
|
|
||||||
def generate(self, *args, **kwargs):
|
|
||||||
# Thin wrapper to raise exceptions when trying to generate with methods that manipulate `past_key_values`.
|
|
||||||
# RWKV is one of the few models that don't have it (it has `state` instead, which has different properties and
|
|
||||||
# usage).
|
|
||||||
try:
|
|
||||||
gen_output = super().generate(*args, **kwargs)
|
|
||||||
except AttributeError as exc:
|
|
||||||
# Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
|
|
||||||
if "past_key_values" in str(exc):
|
|
||||||
raise AttributeError(
|
|
||||||
"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`. RWKV "
|
|
||||||
"doesn't have that attribute, try another generation strategy instead. For the available "
|
|
||||||
"generation strategies, check this doc: https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise exc
|
|
||||||
return gen_output
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
|
||||||
# only last token for inputs_ids if the state is passed along.
|
# only last token for inputs_ids if the state is passed along.
|
||||||
if state is not None:
|
if state is not None:
|
||||||
|
|||||||
@@ -464,6 +464,8 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest("This model doesn't support caching")
|
self.skipTest("This model doesn't support caching")
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||||
|
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest("This model doesn't support caching")
|
self.skipTest("This model doesn't support caching")
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||||
|
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
|
|||||||
@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {}
|
{"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {}
|
||||||
)
|
)
|
||||||
# all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
model = RwkvModel.from_pretrained(model_name)
|
model = RwkvModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_beam_sample_generate_dict_output(self):
|
||||||
|
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||||
|
old_has_attentions = self.has_attentions
|
||||||
|
self.has_attentions = False
|
||||||
|
super().test_beam_sample_generate_dict_output()
|
||||||
|
self.has_attentions = old_has_attentions
|
||||||
|
|
||||||
|
def test_beam_search_generate_dict_output(self):
|
||||||
|
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||||
|
old_has_attentions = self.has_attentions
|
||||||
|
self.has_attentions = False
|
||||||
|
super().test_beam_search_generate_dict_output()
|
||||||
|
self.has_attentions = old_has_attentions
|
||||||
|
|
||||||
|
def test_constrained_beam_search_generate_dict_output(self):
|
||||||
|
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||||
|
old_has_attentions = self.has_attentions
|
||||||
|
self.has_attentions = False
|
||||||
|
super().test_constrained_beam_search_generate_dict_output()
|
||||||
|
self.has_attentions = old_has_attentions
|
||||||
|
|
||||||
|
def test_greedy_generate_dict_outputs(self):
|
||||||
|
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||||
|
old_has_attentions = self.has_attentions
|
||||||
|
self.has_attentions = False
|
||||||
|
super().test_greedy_generate_dict_outputs()
|
||||||
|
self.has_attentions = old_has_attentions
|
||||||
|
|
||||||
|
def test_group_beam_search_generate_dict_output(self):
|
||||||
|
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||||
|
old_has_attentions = self.has_attentions
|
||||||
|
self.has_attentions = False
|
||||||
|
super().test_group_beam_search_generate_dict_output()
|
||||||
|
self.has_attentions = old_has_attentions
|
||||||
|
|
||||||
|
def test_sample_generate_dict_output(self):
|
||||||
|
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||||
|
old_has_attentions = self.has_attentions
|
||||||
|
self.has_attentions = False
|
||||||
|
super().test_sample_generate_dict_output()
|
||||||
|
self.has_attentions = old_has_attentions
|
||||||
|
|
||||||
|
@unittest.skip("This model doesn't support padding")
|
||||||
|
def test_left_padding_compatibility(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||||
|
|||||||
Reference in New Issue
Block a user