Pipeline: no side-effects on model.config and model.generation_config 🔫 (#33480)

This commit is contained in:
Joao Gante
2024-09-18 15:43:06 +01:00
committed by GitHub
parent fc83a4d459
commit 7542fac2c7
13 changed files with 132 additions and 30 deletions

View File

@@ -1229,6 +1229,10 @@ class GenerationConfig(PushToHubMixin):
"""
config_dict = model_config.to_dict()
config_dict.pop("_from_model_config", None)
# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
config_dict = {key: value for key, value in config_dict.items() if value is not None}
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the

View File

@@ -1334,23 +1334,26 @@ class GenerationMixin:
# the following conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
# 3) there are non-default generation parameters in the model config.
# 4) the user must have set new generation parameters in the model config.
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
if (
not is_torchdynamo_compiling()
and self.generation_config._from_model_config # 1)
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config: # 3)
if new_generation_config != self.generation_config: # 4)
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" deprecated strategy to control generation and will be removed in v5."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
UserWarning,
)
self.generation_config = new_generation_config
using_model_generation_config = True
generation_config = self.generation_config
using_model_generation_config = True

View File

@@ -501,6 +501,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
else:
generate_kwargs["num_frames"] = num_frames
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
tokens = self.model.generate(
inputs=inputs,
attention_mask=attention_mask,

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import csv
import importlib
import json
@@ -899,22 +900,26 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
):
self.model.to(self.device)
# Update config and generation_config with task specific parameters
# If the model can generate, create a local generation config. This is done to avoid side-effects on the model
# as we apply local tweaks to the generation config.
if self.model.can_generate():
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
self.generation_config = copy.deepcopy(self.model.generation_config)
# Update the generation config with task specific params if they exist
# NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config.
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
if self.model.can_generate():
self.model.generation_config.update(**task_specific_params.get(task))
# Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the
# forward params so that `generate` is aware of the pad token.
this_task_params = task_specific_params.get(task)
if "prefix" in this_task_params:
self.prefix = this_task_params.pop("prefix")
self.generation_config.update(**this_task_params)
# If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it.
if (
self.tokenizer is not None
and self.model.can_generate()
and self.tokenizer.pad_token_id is not None
and self.model.generation_config.pad_token_id is None
and self.generation_config.pad_token_id is None
):
self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
self.call_count = 0
self._batch_size = kwargs.pop("batch_size", None)

View File

@@ -429,6 +429,10 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
is_last = model_inputs.pop("is_last", False)
if self.model_type == ModelType.VisionEncoderDecoder:
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else:
model_outputs = self.model(**model_inputs)

View File

@@ -181,6 +181,10 @@ class ImageToTextPipeline(Pipeline):
):
model_inputs["input_ids"] = None
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`

View File

@@ -385,6 +385,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
else:
outputs = self.batch_inference(**model_inputs)
else:
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
outputs = self.model.generate(**model_inputs, **generate_kwargs)
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
return model_outputs

View File

@@ -115,7 +115,7 @@ class Text2TextGenerationPipeline(Pipeline):
return True
def _parse_and_tokenize(self, *args, truncation):
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
prefix = self.prefix if self.prefix is not None else ""
if isinstance(args[0], list):
if self.tokenizer.pad_token_id is None:
raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
@@ -185,9 +185,14 @@ class Text2TextGenerationPipeline(Pipeline):
self.check_inputs(
input_length,
generate_kwargs.get("min_length", self.model.config.min_length),
generate_kwargs.get("max_length", self.model.config.max_length),
generate_kwargs.get("min_length", self.generation_config.min_length),
generate_kwargs.get("max_length", self.generation_config.max_length),
)
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
out_b = output_ids.shape[0]
if self.framework == "pt":

View File

@@ -103,8 +103,8 @@ class TextGenerationPipeline(Pipeline):
# It also defines both some preprocess_kwargs and generate_kwargs
# which is why we cannot put them in their respective methods.
prefix = None
if self.model.config.prefix is not None:
prefix = self.model.config.prefix
if self.prefix is not None:
prefix = self.prefix
if prefix is None and self.model.__class__.__name__ in [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
@@ -316,7 +316,7 @@ class TextGenerationPipeline(Pipeline):
if "max_new_tokens" in generate_kwargs:
new_tokens = generate_kwargs["max_new_tokens"]
else:
new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
new_tokens = generate_kwargs.get("max_length", self.generation_config.max_length) - cur_len
if new_tokens < 0:
raise ValueError("We cannot infer how many new tokens are expected")
if cur_len + new_tokens > self.tokenizer.model_max_length:
@@ -354,7 +354,7 @@ class TextGenerationPipeline(Pipeline):
and generate_kwargs["generation_config"].max_new_tokens is not None
)
if not has_max_new_tokens:
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.generation_config.max_length
generate_kwargs["max_length"] += prefix_length
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
@@ -363,7 +363,10 @@ class TextGenerationPipeline(Pipeline):
if not has_min_new_tokens and "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length
# BS x SL
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
out_b = generated_sequence.shape[0]
if self.framework == "pt":

View File

@@ -111,7 +111,7 @@ class TextToAudioPipeline(Pipeline):
if self.model.config.model_type == "bark":
# bark Tokenizer is called with BarkProcessor which uses those kwargs
new_kwargs = {
"max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
"max_length": self.generation_config.semantic_config.get("max_input_semantic_length", 256),
"add_special_tokens": False,
"return_attention_mask": True,
"return_token_type_ids": False,
@@ -137,6 +137,10 @@ class TextToAudioPipeline(Pipeline):
# we expect some kwargs to be additional tensors which need to be on the right device
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
# generate_kwargs get priority over forward_params
forward_params.update(generate_kwargs)

View File

@@ -162,6 +162,10 @@ class VisualQuestionAnsweringPipeline(Pipeline):
def _forward(self, model_inputs, **generate_kwargs):
if self.model.can_generate():
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else:
model_outputs = self.model(**model_inputs)

View File

@@ -31,6 +31,7 @@ from transformers import (
AutoTokenizer,
DistilBertForSequenceClassification,
MaskGenerationPipeline,
T5ForConditionalGeneration,
TextClassificationPipeline,
TextGenerationPipeline,
TFAutoModelForSequenceClassification,
@@ -234,6 +235,31 @@ class CommonPipelineTest(unittest.TestCase):
self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load
@require_torch
def test_pipeline_with_task_parameters_no_side_effects(self):
"""
Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected
side-effects
"""
# This checkpoint has task-specific parameters that will modify the behavior of the pipeline
model = T5ForConditionalGeneration.from_pretrained("t5-small")
self.assertTrue(model.config.num_beams == 1)
# The task-specific parameters used to cause side-effects on `model.config` -- not anymore
pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de")
self.assertTrue(model.config.num_beams == 1)
self.assertTrue(model.generation_config.num_beams == 1)
# Under the hood: we now store a generation config in the pipeline. This generation config stores the
# task-specific paremeters.
self.assertTrue(pipe.generation_config.num_beams == 4)
# We can confirm that the task-specific parameters have an effect. (In this case, the default is `num_beams=1`,
# which would crash when `num_return_sequences=4` is passed.)
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4)
with self.assertRaises(ValueError):
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4, num_beams=1)
@is_pipeline_test
class PipelineScikitCompatTest(unittest.TestCase):

View File

@@ -1715,6 +1715,38 @@ class ModelUtilsTest(TestCasePlus):
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
def test_save_and_load_config_with_custom_generation(self):
"""
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
that gets moved to the generation config and reset on the model config)
"""
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
# The default for `num_beams` is 1 and `early_stopping` is False
self.assertTrue(model.config.num_beams == 1)
self.assertTrue(model.config.early_stopping is False)
# When we save the model, this custom parameter should be moved to the generation config AND the model
# config should contain `None`
model.config.num_beams = 2
model.config.early_stopping = True
self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
# moved to generation config
self.assertTrue(new_model.generation_config.num_beams == 2)
self.assertTrue(new_model.generation_config.early_stopping is True)
# reset in the model config
self.assertTrue(new_model.config.num_beams is None)
self.assertTrue(new_model.config.early_stopping is None)
# Sanity check: We can run `generate` with the new model without any warnings
random_ids = torch.randint(0, 100, (1, 5))
with warnings.catch_warnings(record=True) as w:
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)
@slow
@require_torch