From e7f33e8cb38c446ce4004a742ebb386a4bef7174 Mon Sep 17 00:00:00 2001 From: Alex Hedges Date: Fri, 9 Jul 2021 09:24:55 -0400 Subject: [PATCH] Pass `model_kwargs` when loading a model in `pipeline()` (#12449) * Pass model_kwargs when loading a model in pipeline * Add test for model_kwargs parameter of pipeline() * Rewrite test to not download model * Fix failing style checks --- src/transformers/pipelines/__init__.py | 8 +++++++- tests/test_pipelines_token_classification.py | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 64a368a406..feae7a827d 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -426,7 +426,13 @@ def pipeline( # Will load the correct model if possible model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]} framework, model = infer_framework_load_model( - model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task + model, + model_classes=model_classes, + config=config, + framework=framework, + revision=revision, + task=task, + **model_kwargs, ) model_config = model.config diff --git a/tests/test_pipelines_token_classification.py b/tests/test_pipelines_token_classification.py index 4197dae5da..ce33656314 100644 --- a/tests/test_pipelines_token_classification.py +++ b/tests/test_pipelines_token_classification.py @@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. for key in output_keys: self.assertIn(key, result) + @require_torch + def test_model_kwargs_passed_to_model_load(self): + ner_pipeline = pipeline(task="ner", model=self.small_models[0]) + self.assertFalse(ner_pipeline.model.config.output_attentions) + ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True}) + self.assertTrue(ner_pipeline.model.config.output_attentions) + @require_torch @slow def test_spanish_bert(self):