RWKV: enable generation tests (#31490)

* add rwkv tests

* has_attentions set in individual tests
This commit is contained in:
Joao Gante
2024-06-20 14:15:01 +01:00
committed by GitHub
parent d28e647f28
commit 1fd60fec75
3 changed files with 54 additions and 19 deletions

View File

@@ -464,6 +464,8 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
config.use_cache = True
config.is_decoder = True
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(