@@ -483,7 +483,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
for model_class in decoder_only_classes:
|
for model_class in decoder_only_classes:
|
||||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
signature = inspect.signature(model.forward).parameters.keys()
|
signature = inspect.signature(model.forward).parameters.keys()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user