Jamba: fix left-padding test (#30389)

fix test
This commit is contained in:
Joao Gante
2024-04-22 17:02:55 +01:00
committed by GitHub
parent f3b3533e19
commit 6c7335e053

View File

@@ -483,7 +483,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
return model_kwargs
for model_class in decoder_only_classes:
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()