Pipeline: no side-effects on model.config and model.generation_config 🔫 (#33480)

This commit is contained in:
Joao Gante
2024-09-18 15:43:06 +01:00
committed by GitHub
parent fc83a4d459
commit 7542fac2c7
13 changed files with 132 additions and 30 deletions

View File

@@ -1229,6 +1229,10 @@ class GenerationConfig(PushToHubMixin):
""" """
config_dict = model_config.to_dict() config_dict = model_config.to_dict()
config_dict.pop("_from_model_config", None) config_dict.pop("_from_model_config", None)
# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
config_dict = {key: value for key, value in config_dict.items() if value is not None}
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True) generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the # Special case: some models have generation attributes set in the decoder. Use them if still unset in the

View File

@@ -1334,23 +1334,26 @@ class GenerationMixin:
# the following conditions must be met # the following conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field); # 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same); # 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config. # 3) there are non-default generation parameters in the model config.
# 4) the user must have set new generation parameters in the model config.
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation. # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
if ( if (
not is_torchdynamo_compiling() not is_torchdynamo_compiling()
and self.generation_config._from_model_config # 1) and self.generation_config._from_model_config # 1)
and self.generation_config._original_object_hash == hash(self.generation_config) # 2) and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
): ):
new_generation_config = GenerationConfig.from_model_config(self.config) new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config: # 3) if new_generation_config != self.generation_config: # 4)
warnings.warn( warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a" "You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version." " deprecated strategy to control generation and will be removed in v5."
" Please use and modify the model generation configuration (see" " Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
UserWarning,
) )
self.generation_config = new_generation_config self.generation_config = new_generation_config
using_model_generation_config = True
generation_config = self.generation_config generation_config = self.generation_config
using_model_generation_config = True using_model_generation_config = True

View File

@@ -501,6 +501,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
else: else:
generate_kwargs["num_frames"] = num_frames generate_kwargs["num_frames"] = num_frames
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
tokens = self.model.generate( tokens = self.model.generate(
inputs=inputs, inputs=inputs,
attention_mask=attention_mask, attention_mask=attention_mask,

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
import copy
import csv import csv
import importlib import importlib
import json import json
@@ -899,22 +900,26 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
): ):
self.model.to(self.device) self.model.to(self.device)
# Update config and generation_config with task specific parameters # If the model can generate, create a local generation config. This is done to avoid side-effects on the model
task_specific_params = self.model.config.task_specific_params # as we apply local tweaks to the generation config.
if task_specific_params is not None and task in task_specific_params: if self.model.can_generate():
self.model.config.update(task_specific_params.get(task)) self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
if self.model.can_generate(): self.generation_config = copy.deepcopy(self.model.generation_config)
self.model.generation_config.update(**task_specific_params.get(task)) # Update the generation config with task specific params if they exist
# NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config.
# Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the task_specific_params = self.model.config.task_specific_params
# forward params so that `generate` is aware of the pad token. if task_specific_params is not None and task in task_specific_params:
if ( this_task_params = task_specific_params.get(task)
self.tokenizer is not None if "prefix" in this_task_params:
and self.model.can_generate() self.prefix = this_task_params.pop("prefix")
and self.tokenizer.pad_token_id is not None self.generation_config.update(**this_task_params)
and self.model.generation_config.pad_token_id is None # If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it.
): if (
self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id self.tokenizer is not None
and self.tokenizer.pad_token_id is not None
and self.generation_config.pad_token_id is None
):
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
self.call_count = 0 self.call_count = 0
self._batch_size = kwargs.pop("batch_size", None) self._batch_size = kwargs.pop("batch_size", None)

View File

@@ -429,6 +429,10 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
is_last = model_inputs.pop("is_last", False) is_last = model_inputs.pop("is_last", False)
if self.model_type == ModelType.VisionEncoderDecoder: if self.model_type == ModelType.VisionEncoderDecoder:
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
model_outputs = self.model.generate(**model_inputs, **generate_kwargs) model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else: else:
model_outputs = self.model(**model_inputs) model_outputs = self.model(**model_inputs)

View File

@@ -181,6 +181,10 @@ class ImageToTextPipeline(Pipeline):
): ):
model_inputs["input_ids"] = None model_inputs["input_ids"] = None
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py` # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name` # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`

View File

@@ -385,6 +385,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
else: else:
outputs = self.batch_inference(**model_inputs) outputs = self.batch_inference(**model_inputs)
else: else:
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
outputs = self.model.generate(**model_inputs, **generate_kwargs) outputs = self.model.generate(**model_inputs, **generate_kwargs)
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs} model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
return model_outputs return model_outputs

View File

@@ -115,7 +115,7 @@ class Text2TextGenerationPipeline(Pipeline):
return True return True
def _parse_and_tokenize(self, *args, truncation): def _parse_and_tokenize(self, *args, truncation):
prefix = self.model.config.prefix if self.model.config.prefix is not None else "" prefix = self.prefix if self.prefix is not None else ""
if isinstance(args[0], list): if isinstance(args[0], list):
if self.tokenizer.pad_token_id is None: if self.tokenizer.pad_token_id is None:
raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input") raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
@@ -185,9 +185,14 @@ class Text2TextGenerationPipeline(Pipeline):
self.check_inputs( self.check_inputs(
input_length, input_length,
generate_kwargs.get("min_length", self.model.config.min_length), generate_kwargs.get("min_length", self.generation_config.min_length),
generate_kwargs.get("max_length", self.model.config.max_length), generate_kwargs.get("max_length", self.generation_config.max_length),
) )
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
output_ids = self.model.generate(**model_inputs, **generate_kwargs) output_ids = self.model.generate(**model_inputs, **generate_kwargs)
out_b = output_ids.shape[0] out_b = output_ids.shape[0]
if self.framework == "pt": if self.framework == "pt":

View File

@@ -103,8 +103,8 @@ class TextGenerationPipeline(Pipeline):
# It also defines both some preprocess_kwargs and generate_kwargs # It also defines both some preprocess_kwargs and generate_kwargs
# which is why we cannot put them in their respective methods. # which is why we cannot put them in their respective methods.
prefix = None prefix = None
if self.model.config.prefix is not None: if self.prefix is not None:
prefix = self.model.config.prefix prefix = self.prefix
if prefix is None and self.model.__class__.__name__ in [ if prefix is None and self.model.__class__.__name__ in [
"XLNetLMHeadModel", "XLNetLMHeadModel",
"TransfoXLLMHeadModel", "TransfoXLLMHeadModel",
@@ -316,7 +316,7 @@ class TextGenerationPipeline(Pipeline):
if "max_new_tokens" in generate_kwargs: if "max_new_tokens" in generate_kwargs:
new_tokens = generate_kwargs["max_new_tokens"] new_tokens = generate_kwargs["max_new_tokens"]
else: else:
new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len new_tokens = generate_kwargs.get("max_length", self.generation_config.max_length) - cur_len
if new_tokens < 0: if new_tokens < 0:
raise ValueError("We cannot infer how many new tokens are expected") raise ValueError("We cannot infer how many new tokens are expected")
if cur_len + new_tokens > self.tokenizer.model_max_length: if cur_len + new_tokens > self.tokenizer.model_max_length:
@@ -354,7 +354,7 @@ class TextGenerationPipeline(Pipeline):
and generate_kwargs["generation_config"].max_new_tokens is not None and generate_kwargs["generation_config"].max_new_tokens is not None
) )
if not has_max_new_tokens: if not has_max_new_tokens:
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.generation_config.max_length
generate_kwargs["max_length"] += prefix_length generate_kwargs["max_length"] += prefix_length
has_min_new_tokens = "min_new_tokens" in generate_kwargs or ( has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs "generation_config" in generate_kwargs
@@ -363,7 +363,10 @@ class TextGenerationPipeline(Pipeline):
if not has_min_new_tokens and "min_length" in generate_kwargs: if not has_min_new_tokens and "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length generate_kwargs["min_length"] += prefix_length
# BS x SL # User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
out_b = generated_sequence.shape[0] out_b = generated_sequence.shape[0]
if self.framework == "pt": if self.framework == "pt":

View File

@@ -111,7 +111,7 @@ class TextToAudioPipeline(Pipeline):
if self.model.config.model_type == "bark": if self.model.config.model_type == "bark":
# bark Tokenizer is called with BarkProcessor which uses those kwargs # bark Tokenizer is called with BarkProcessor which uses those kwargs
new_kwargs = { new_kwargs = {
"max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256), "max_length": self.generation_config.semantic_config.get("max_input_semantic_length", 256),
"add_special_tokens": False, "add_special_tokens": False,
"return_attention_mask": True, "return_attention_mask": True,
"return_token_type_ids": False, "return_token_type_ids": False,
@@ -137,6 +137,10 @@ class TextToAudioPipeline(Pipeline):
# we expect some kwargs to be additional tensors which need to be on the right device # we expect some kwargs to be additional tensors which need to be on the right device
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device) generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
# generate_kwargs get priority over forward_params # generate_kwargs get priority over forward_params
forward_params.update(generate_kwargs) forward_params.update(generate_kwargs)

View File

@@ -162,6 +162,10 @@ class VisualQuestionAnsweringPipeline(Pipeline):
def _forward(self, model_inputs, **generate_kwargs): def _forward(self, model_inputs, **generate_kwargs):
if self.model.can_generate(): if self.model.can_generate():
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
model_outputs = self.model.generate(**model_inputs, **generate_kwargs) model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else: else:
model_outputs = self.model(**model_inputs) model_outputs = self.model(**model_inputs)

View File

@@ -31,6 +31,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
DistilBertForSequenceClassification, DistilBertForSequenceClassification,
MaskGenerationPipeline, MaskGenerationPipeline,
T5ForConditionalGeneration,
TextClassificationPipeline, TextClassificationPipeline,
TextGenerationPipeline, TextGenerationPipeline,
TFAutoModelForSequenceClassification, TFAutoModelForSequenceClassification,
@@ -234,6 +235,31 @@ class CommonPipelineTest(unittest.TestCase):
self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load
@require_torch
def test_pipeline_with_task_parameters_no_side_effects(self):
"""
Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected
side-effects
"""
# This checkpoint has task-specific parameters that will modify the behavior of the pipeline
model = T5ForConditionalGeneration.from_pretrained("t5-small")
self.assertTrue(model.config.num_beams == 1)
# The task-specific parameters used to cause side-effects on `model.config` -- not anymore
pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de")
self.assertTrue(model.config.num_beams == 1)
self.assertTrue(model.generation_config.num_beams == 1)
# Under the hood: we now store a generation config in the pipeline. This generation config stores the
# task-specific paremeters.
self.assertTrue(pipe.generation_config.num_beams == 4)
# We can confirm that the task-specific parameters have an effect. (In this case, the default is `num_beams=1`,
# which would crash when `num_return_sequences=4` is passed.)
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4)
with self.assertRaises(ValueError):
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4, num_beams=1)
@is_pipeline_test @is_pipeline_test
class PipelineScikitCompatTest(unittest.TestCase): class PipelineScikitCompatTest(unittest.TestCase):

View File

@@ -1715,6 +1715,38 @@ class ModelUtilsTest(TestCasePlus):
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
) )
def test_save_and_load_config_with_custom_generation(self):
"""
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
that gets moved to the generation config and reset on the model config)
"""
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
# The default for `num_beams` is 1 and `early_stopping` is False
self.assertTrue(model.config.num_beams == 1)
self.assertTrue(model.config.early_stopping is False)
# When we save the model, this custom parameter should be moved to the generation config AND the model
# config should contain `None`
model.config.num_beams = 2
model.config.early_stopping = True
self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
# moved to generation config
self.assertTrue(new_model.generation_config.num_beams == 2)
self.assertTrue(new_model.generation_config.early_stopping is True)
# reset in the model config
self.assertTrue(new_model.config.num_beams is None)
self.assertTrue(new_model.config.early_stopping is None)
# Sanity check: We can run `generate` with the new model without any warnings
random_ids = torch.randint(0, 100, (1, 5))
with warnings.catch_warnings(record=True) as w:
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)
@slow @slow
@require_torch @require_torch