Pass model_kwargs when loading a model in pipeline() (#12449)
* Pass model_kwargs when loading a model in pipeline * Add test for model_kwargs parameter of pipeline() * Rewrite test to not download model * Fix failing style checks
This commit is contained in:
@@ -426,7 +426,13 @@ def pipeline(
|
|||||||
# Will load the correct model if possible
|
# Will load the correct model if possible
|
||||||
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
||||||
framework, model = infer_framework_load_model(
|
framework, model = infer_framework_load_model(
|
||||||
model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
|
model,
|
||||||
|
model_classes=model_classes,
|
||||||
|
config=config,
|
||||||
|
framework=framework,
|
||||||
|
revision=revision,
|
||||||
|
task=task,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
|||||||
for key in output_keys:
|
for key in output_keys:
|
||||||
self.assertIn(key, result)
|
self.assertIn(key, result)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_model_kwargs_passed_to_model_load(self):
|
||||||
|
ner_pipeline = pipeline(task="ner", model=self.small_models[0])
|
||||||
|
self.assertFalse(ner_pipeline.model.config.output_attentions)
|
||||||
|
ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True})
|
||||||
|
self.assertTrue(ner_pipeline.model.config.output_attentions)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_spanish_bert(self):
|
def test_spanish_bert(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user