@@ -483,7 +483,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
return model_kwargs
|
||||
|
||||
for model_class in decoder_only_classes:
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
signature = inspect.signature(model.forward).parameters.keys()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user