Pipeline: no side-effects on model.config and model.generation_config 🔫 (#33480)
This commit is contained in:
@@ -31,6 +31,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
DistilBertForSequenceClassification,
|
||||
MaskGenerationPipeline,
|
||||
T5ForConditionalGeneration,
|
||||
TextClassificationPipeline,
|
||||
TextGenerationPipeline,
|
||||
TFAutoModelForSequenceClassification,
|
||||
@@ -234,6 +235,31 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_with_task_parameters_no_side_effects(self):
|
||||
"""
|
||||
Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected
|
||||
side-effects
|
||||
"""
|
||||
# This checkpoint has task-specific parameters that will modify the behavior of the pipeline
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
self.assertTrue(model.config.num_beams == 1)
|
||||
|
||||
# The task-specific parameters used to cause side-effects on `model.config` -- not anymore
|
||||
pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de")
|
||||
self.assertTrue(model.config.num_beams == 1)
|
||||
self.assertTrue(model.generation_config.num_beams == 1)
|
||||
|
||||
# Under the hood: we now store a generation config in the pipeline. This generation config stores the
|
||||
# task-specific paremeters.
|
||||
self.assertTrue(pipe.generation_config.num_beams == 4)
|
||||
|
||||
# We can confirm that the task-specific parameters have an effect. (In this case, the default is `num_beams=1`,
|
||||
# which would crash when `num_return_sequences=4` is passed.)
|
||||
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4)
|
||||
with self.assertRaises(ValueError):
|
||||
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4, num_beams=1)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelineScikitCompatTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user