Pass model_kwargs when loading a model in pipeline() (#12449)

* Pass model_kwargs when loading a model in pipeline

* Add test for model_kwargs parameter of pipeline()

* Rewrite test to not download model

* Fix failing style checks
This commit is contained in:
Alex Hedges
2021-07-09 09:24:55 -04:00
committed by GitHub
parent 18ca59e1d3
commit e7f33e8cb3
2 changed files with 14 additions and 1 deletions

View File

@@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
for key in output_keys:
self.assertIn(key, result)
@require_torch
def test_model_kwargs_passed_to_model_load(self):
ner_pipeline = pipeline(task="ner", model=self.small_models[0])
self.assertFalse(ner_pipeline.model.config.output_attentions)
ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True})
self.assertTrue(ner_pipeline.model.config.output_attentions)
@require_torch
@slow
def test_spanish_bert(self):