RWKV: enable generation tests (#31490)
* add rwkv tests * has_attentions set in individual tests
This commit is contained in:
@@ -464,6 +464,8 @@ class GenerationTesterMixin:
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
|
||||
Reference in New Issue
Block a user