Pass model_kwargs when loading a model in pipeline() (#12449)
* Pass model_kwargs when loading a model in pipeline * Add test for model_kwargs parameter of pipeline() * Rewrite test to not download model * Fix failing style checks
This commit is contained in:
@@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
for key in output_keys:
|
||||
self.assertIn(key, result)
|
||||
|
||||
@require_torch
|
||||
def test_model_kwargs_passed_to_model_load(self):
|
||||
ner_pipeline = pipeline(task="ner", model=self.small_models[0])
|
||||
self.assertFalse(ner_pipeline.model.config.output_attentions)
|
||||
ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True})
|
||||
self.assertTrue(ner_pipeline.model.config.output_attentions)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_spanish_bert(self):
|
||||
|
||||
Reference in New Issue
Block a user