Pipeline: no side-effects on model.config and model.generation_config 🔫 (#33480)
This commit is contained in:
@@ -1229,6 +1229,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
"""
|
||||
config_dict = model_config.to_dict()
|
||||
config_dict.pop("_from_model_config", None)
|
||||
|
||||
# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
|
||||
config_dict = {key: value for key, value in config_dict.items() if value is not None}
|
||||
|
||||
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
|
||||
|
||||
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
||||
|
||||
@@ -1334,23 +1334,26 @@ class GenerationMixin:
|
||||
# the following conditions must be met
|
||||
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||
# 2) the generation config must have seen no modification since its creation (the hash is the same);
|
||||
# 3) the user must have set generation parameters in the model config.
|
||||
# 3) there are non-default generation parameters in the model config.
|
||||
# 4) the user must have set new generation parameters in the model config.
|
||||
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and self.generation_config._from_model_config # 1)
|
||||
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
|
||||
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
|
||||
):
|
||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||
if new_generation_config != self.generation_config: # 3)
|
||||
if new_generation_config != self.generation_config: # 4)
|
||||
warnings.warn(
|
||||
"You have modified the pretrained model configuration to control generation. This is a"
|
||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||
" deprecated strategy to control generation and will be removed in v5."
|
||||
" Please use and modify the model generation configuration (see"
|
||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
|
||||
UserWarning,
|
||||
)
|
||||
self.generation_config = new_generation_config
|
||||
using_model_generation_config = True
|
||||
|
||||
generation_config = self.generation_config
|
||||
using_model_generation_config = True
|
||||
|
||||
|
||||
@@ -501,6 +501,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
else:
|
||||
generate_kwargs["num_frames"] = num_frames
|
||||
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
tokens = self.model.generate(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import copy
|
||||
import csv
|
||||
import importlib
|
||||
import json
|
||||
@@ -899,22 +900,26 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
):
|
||||
self.model.to(self.device)
|
||||
|
||||
# Update config and generation_config with task specific parameters
|
||||
# If the model can generate, create a local generation config. This is done to avoid side-effects on the model
|
||||
# as we apply local tweaks to the generation config.
|
||||
if self.model.can_generate():
|
||||
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
|
||||
self.generation_config = copy.deepcopy(self.model.generation_config)
|
||||
# Update the generation config with task specific params if they exist
|
||||
# NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config.
|
||||
task_specific_params = self.model.config.task_specific_params
|
||||
if task_specific_params is not None and task in task_specific_params:
|
||||
self.model.config.update(task_specific_params.get(task))
|
||||
if self.model.can_generate():
|
||||
self.model.generation_config.update(**task_specific_params.get(task))
|
||||
|
||||
# Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the
|
||||
# forward params so that `generate` is aware of the pad token.
|
||||
this_task_params = task_specific_params.get(task)
|
||||
if "prefix" in this_task_params:
|
||||
self.prefix = this_task_params.pop("prefix")
|
||||
self.generation_config.update(**this_task_params)
|
||||
# If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it.
|
||||
if (
|
||||
self.tokenizer is not None
|
||||
and self.model.can_generate()
|
||||
and self.tokenizer.pad_token_id is not None
|
||||
and self.model.generation_config.pad_token_id is None
|
||||
and self.generation_config.pad_token_id is None
|
||||
):
|
||||
self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
||||
|
||||
self.call_count = 0
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
|
||||
@@ -429,6 +429,10 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
||||
is_last = model_inputs.pop("is_last", False)
|
||||
|
||||
if self.model_type == ModelType.VisionEncoderDecoder:
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
||||
else:
|
||||
model_outputs = self.model(**model_inputs)
|
||||
|
||||
@@ -181,6 +181,10 @@ class ImageToTextPipeline(Pipeline):
|
||||
):
|
||||
model_inputs["input_ids"] = None
|
||||
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
|
||||
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
|
||||
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
|
||||
|
||||
@@ -385,6 +385,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
else:
|
||||
outputs = self.batch_inference(**model_inputs)
|
||||
else:
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
||||
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
|
||||
return model_outputs
|
||||
|
||||
@@ -115,7 +115,7 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
return True
|
||||
|
||||
def _parse_and_tokenize(self, *args, truncation):
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
prefix = self.prefix if self.prefix is not None else ""
|
||||
if isinstance(args[0], list):
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
|
||||
@@ -185,9 +185,14 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
|
||||
self.check_inputs(
|
||||
input_length,
|
||||
generate_kwargs.get("min_length", self.model.config.min_length),
|
||||
generate_kwargs.get("max_length", self.model.config.max_length),
|
||||
generate_kwargs.get("min_length", self.generation_config.min_length),
|
||||
generate_kwargs.get("max_length", self.generation_config.max_length),
|
||||
)
|
||||
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
||||
out_b = output_ids.shape[0]
|
||||
if self.framework == "pt":
|
||||
|
||||
@@ -103,8 +103,8 @@ class TextGenerationPipeline(Pipeline):
|
||||
# It also defines both some preprocess_kwargs and generate_kwargs
|
||||
# which is why we cannot put them in their respective methods.
|
||||
prefix = None
|
||||
if self.model.config.prefix is not None:
|
||||
prefix = self.model.config.prefix
|
||||
if self.prefix is not None:
|
||||
prefix = self.prefix
|
||||
if prefix is None and self.model.__class__.__name__ in [
|
||||
"XLNetLMHeadModel",
|
||||
"TransfoXLLMHeadModel",
|
||||
@@ -316,7 +316,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
if "max_new_tokens" in generate_kwargs:
|
||||
new_tokens = generate_kwargs["max_new_tokens"]
|
||||
else:
|
||||
new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
|
||||
new_tokens = generate_kwargs.get("max_length", self.generation_config.max_length) - cur_len
|
||||
if new_tokens < 0:
|
||||
raise ValueError("We cannot infer how many new tokens are expected")
|
||||
if cur_len + new_tokens > self.tokenizer.model_max_length:
|
||||
@@ -354,7 +354,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
and generate_kwargs["generation_config"].max_new_tokens is not None
|
||||
)
|
||||
if not has_max_new_tokens:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.generation_config.max_length
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
@@ -363,7 +363,10 @@ class TextGenerationPipeline(Pipeline):
|
||||
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
|
||||
# BS x SL
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == "pt":
|
||||
|
||||
@@ -111,7 +111,7 @@ class TextToAudioPipeline(Pipeline):
|
||||
if self.model.config.model_type == "bark":
|
||||
# bark Tokenizer is called with BarkProcessor which uses those kwargs
|
||||
new_kwargs = {
|
||||
"max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
|
||||
"max_length": self.generation_config.semantic_config.get("max_input_semantic_length", 256),
|
||||
"add_special_tokens": False,
|
||||
"return_attention_mask": True,
|
||||
"return_token_type_ids": False,
|
||||
@@ -137,6 +137,10 @@ class TextToAudioPipeline(Pipeline):
|
||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
|
||||
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
# generate_kwargs get priority over forward_params
|
||||
forward_params.update(generate_kwargs)
|
||||
|
||||
|
||||
@@ -162,6 +162,10 @@ class VisualQuestionAnsweringPipeline(Pipeline):
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
if self.model.can_generate():
|
||||
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
||||
else:
|
||||
model_outputs = self.model(**model_inputs)
|
||||
|
||||
@@ -31,6 +31,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
DistilBertForSequenceClassification,
|
||||
MaskGenerationPipeline,
|
||||
T5ForConditionalGeneration,
|
||||
TextClassificationPipeline,
|
||||
TextGenerationPipeline,
|
||||
TFAutoModelForSequenceClassification,
|
||||
@@ -234,6 +235,31 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_with_task_parameters_no_side_effects(self):
|
||||
"""
|
||||
Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected
|
||||
side-effects
|
||||
"""
|
||||
# This checkpoint has task-specific parameters that will modify the behavior of the pipeline
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
self.assertTrue(model.config.num_beams == 1)
|
||||
|
||||
# The task-specific parameters used to cause side-effects on `model.config` -- not anymore
|
||||
pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de")
|
||||
self.assertTrue(model.config.num_beams == 1)
|
||||
self.assertTrue(model.generation_config.num_beams == 1)
|
||||
|
||||
# Under the hood: we now store a generation config in the pipeline. This generation config stores the
|
||||
# task-specific paremeters.
|
||||
self.assertTrue(pipe.generation_config.num_beams == 4)
|
||||
|
||||
# We can confirm that the task-specific parameters have an effect. (In this case, the default is `num_beams=1`,
|
||||
# which would crash when `num_return_sequences=4` is passed.)
|
||||
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4)
|
||||
with self.assertRaises(ValueError):
|
||||
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4, num_beams=1)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelineScikitCompatTest(unittest.TestCase):
|
||||
|
||||
@@ -1715,6 +1715,38 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
|
||||
def test_save_and_load_config_with_custom_generation(self):
|
||||
"""
|
||||
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
|
||||
that gets moved to the generation config and reset on the model config)
|
||||
"""
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
|
||||
# The default for `num_beams` is 1 and `early_stopping` is False
|
||||
self.assertTrue(model.config.num_beams == 1)
|
||||
self.assertTrue(model.config.early_stopping is False)
|
||||
|
||||
# When we save the model, this custom parameter should be moved to the generation config AND the model
|
||||
# config should contain `None`
|
||||
model.config.num_beams = 2
|
||||
model.config.early_stopping = True
|
||||
self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
|
||||
# moved to generation config
|
||||
self.assertTrue(new_model.generation_config.num_beams == 2)
|
||||
self.assertTrue(new_model.generation_config.early_stopping is True)
|
||||
# reset in the model config
|
||||
self.assertTrue(new_model.config.num_beams is None)
|
||||
self.assertTrue(new_model.config.early_stopping is None)
|
||||
|
||||
# Sanity check: We can run `generate` with the new model without any warnings
|
||||
random_ids = torch.randint(0, 100, (1, 5))
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
new_model.generate(random_ids, max_new_tokens=3)
|
||||
self.assertTrue(len(w) == 0)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user