Pipeline: no side-effects on model.config and model.generation_config 🔫 (#33480)
This commit is contained in:
@@ -1229,6 +1229,10 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
"""
|
"""
|
||||||
config_dict = model_config.to_dict()
|
config_dict = model_config.to_dict()
|
||||||
config_dict.pop("_from_model_config", None)
|
config_dict.pop("_from_model_config", None)
|
||||||
|
|
||||||
|
# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
|
||||||
|
config_dict = {key: value for key, value in config_dict.items() if value is not None}
|
||||||
|
|
||||||
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
|
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
|
||||||
|
|
||||||
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
||||||
|
|||||||
@@ -1334,23 +1334,26 @@ class GenerationMixin:
|
|||||||
# the following conditions must be met
|
# the following conditions must be met
|
||||||
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||||
# 2) the generation config must have seen no modification since its creation (the hash is the same);
|
# 2) the generation config must have seen no modification since its creation (the hash is the same);
|
||||||
# 3) the user must have set generation parameters in the model config.
|
# 3) there are non-default generation parameters in the model config.
|
||||||
|
# 4) the user must have set new generation parameters in the model config.
|
||||||
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
|
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
|
||||||
if (
|
if (
|
||||||
not is_torchdynamo_compiling()
|
not is_torchdynamo_compiling()
|
||||||
and self.generation_config._from_model_config # 1)
|
and self.generation_config._from_model_config # 1)
|
||||||
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
|
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
|
||||||
|
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
|
||||||
):
|
):
|
||||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||||
if new_generation_config != self.generation_config: # 3)
|
if new_generation_config != self.generation_config: # 4)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"You have modified the pretrained model configuration to control generation. This is a"
|
"You have modified the pretrained model configuration to control generation. This is a"
|
||||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
" deprecated strategy to control generation and will be removed in v5."
|
||||||
" Please use and modify the model generation configuration (see"
|
" Please use and modify the model generation configuration (see"
|
||||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
|
||||||
|
UserWarning,
|
||||||
)
|
)
|
||||||
self.generation_config = new_generation_config
|
self.generation_config = new_generation_config
|
||||||
using_model_generation_config = True
|
|
||||||
generation_config = self.generation_config
|
generation_config = self.generation_config
|
||||||
using_model_generation_config = True
|
using_model_generation_config = True
|
||||||
|
|
||||||
|
|||||||
@@ -501,6 +501,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
else:
|
else:
|
||||||
generate_kwargs["num_frames"] = num_frames
|
generate_kwargs["num_frames"] = num_frames
|
||||||
|
|
||||||
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
tokens = self.model.generate(
|
tokens = self.model.generate(
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections
|
import collections
|
||||||
|
import copy
|
||||||
import csv
|
import csv
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
@@ -899,22 +900,26 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
|||||||
):
|
):
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
# Update config and generation_config with task specific parameters
|
# If the model can generate, create a local generation config. This is done to avoid side-effects on the model
|
||||||
task_specific_params = self.model.config.task_specific_params
|
# as we apply local tweaks to the generation config.
|
||||||
if task_specific_params is not None and task in task_specific_params:
|
if self.model.can_generate():
|
||||||
self.model.config.update(task_specific_params.get(task))
|
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
|
||||||
if self.model.can_generate():
|
self.generation_config = copy.deepcopy(self.model.generation_config)
|
||||||
self.model.generation_config.update(**task_specific_params.get(task))
|
# Update the generation config with task specific params if they exist
|
||||||
|
# NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config.
|
||||||
# Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the
|
task_specific_params = self.model.config.task_specific_params
|
||||||
# forward params so that `generate` is aware of the pad token.
|
if task_specific_params is not None and task in task_specific_params:
|
||||||
if (
|
this_task_params = task_specific_params.get(task)
|
||||||
self.tokenizer is not None
|
if "prefix" in this_task_params:
|
||||||
and self.model.can_generate()
|
self.prefix = this_task_params.pop("prefix")
|
||||||
and self.tokenizer.pad_token_id is not None
|
self.generation_config.update(**this_task_params)
|
||||||
and self.model.generation_config.pad_token_id is None
|
# If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it.
|
||||||
):
|
if (
|
||||||
self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
self.tokenizer is not None
|
||||||
|
and self.tokenizer.pad_token_id is not None
|
||||||
|
and self.generation_config.pad_token_id is None
|
||||||
|
):
|
||||||
|
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
|
||||||
self.call_count = 0
|
self.call_count = 0
|
||||||
self._batch_size = kwargs.pop("batch_size", None)
|
self._batch_size = kwargs.pop("batch_size", None)
|
||||||
|
|||||||
@@ -429,6 +429,10 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
is_last = model_inputs.pop("is_last", False)
|
is_last = model_inputs.pop("is_last", False)
|
||||||
|
|
||||||
if self.model_type == ModelType.VisionEncoderDecoder:
|
if self.model_type == ModelType.VisionEncoderDecoder:
|
||||||
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
||||||
else:
|
else:
|
||||||
model_outputs = self.model(**model_inputs)
|
model_outputs = self.model(**model_inputs)
|
||||||
|
|||||||
@@ -181,6 +181,10 @@ class ImageToTextPipeline(Pipeline):
|
|||||||
):
|
):
|
||||||
model_inputs["input_ids"] = None
|
model_inputs["input_ids"] = None
|
||||||
|
|
||||||
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
|
# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
|
||||||
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
|
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
|
||||||
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
|
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
|
||||||
|
|||||||
@@ -385,6 +385,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
outputs = self.batch_inference(**model_inputs)
|
outputs = self.batch_inference(**model_inputs)
|
||||||
else:
|
else:
|
||||||
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
||||||
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
|
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
|
||||||
return model_outputs
|
return model_outputs
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class Text2TextGenerationPipeline(Pipeline):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *args, truncation):
|
def _parse_and_tokenize(self, *args, truncation):
|
||||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
prefix = self.prefix if self.prefix is not None else ""
|
||||||
if isinstance(args[0], list):
|
if isinstance(args[0], list):
|
||||||
if self.tokenizer.pad_token_id is None:
|
if self.tokenizer.pad_token_id is None:
|
||||||
raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
|
raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
|
||||||
@@ -185,9 +185,14 @@ class Text2TextGenerationPipeline(Pipeline):
|
|||||||
|
|
||||||
self.check_inputs(
|
self.check_inputs(
|
||||||
input_length,
|
input_length,
|
||||||
generate_kwargs.get("min_length", self.model.config.min_length),
|
generate_kwargs.get("min_length", self.generation_config.min_length),
|
||||||
generate_kwargs.get("max_length", self.model.config.max_length),
|
generate_kwargs.get("max_length", self.generation_config.max_length),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
||||||
out_b = output_ids.shape[0]
|
out_b = output_ids.shape[0]
|
||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
|
|||||||
@@ -103,8 +103,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
# It also defines both some preprocess_kwargs and generate_kwargs
|
# It also defines both some preprocess_kwargs and generate_kwargs
|
||||||
# which is why we cannot put them in their respective methods.
|
# which is why we cannot put them in their respective methods.
|
||||||
prefix = None
|
prefix = None
|
||||||
if self.model.config.prefix is not None:
|
if self.prefix is not None:
|
||||||
prefix = self.model.config.prefix
|
prefix = self.prefix
|
||||||
if prefix is None and self.model.__class__.__name__ in [
|
if prefix is None and self.model.__class__.__name__ in [
|
||||||
"XLNetLMHeadModel",
|
"XLNetLMHeadModel",
|
||||||
"TransfoXLLMHeadModel",
|
"TransfoXLLMHeadModel",
|
||||||
@@ -316,7 +316,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
if "max_new_tokens" in generate_kwargs:
|
if "max_new_tokens" in generate_kwargs:
|
||||||
new_tokens = generate_kwargs["max_new_tokens"]
|
new_tokens = generate_kwargs["max_new_tokens"]
|
||||||
else:
|
else:
|
||||||
new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
|
new_tokens = generate_kwargs.get("max_length", self.generation_config.max_length) - cur_len
|
||||||
if new_tokens < 0:
|
if new_tokens < 0:
|
||||||
raise ValueError("We cannot infer how many new tokens are expected")
|
raise ValueError("We cannot infer how many new tokens are expected")
|
||||||
if cur_len + new_tokens > self.tokenizer.model_max_length:
|
if cur_len + new_tokens > self.tokenizer.model_max_length:
|
||||||
@@ -354,7 +354,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
and generate_kwargs["generation_config"].max_new_tokens is not None
|
and generate_kwargs["generation_config"].max_new_tokens is not None
|
||||||
)
|
)
|
||||||
if not has_max_new_tokens:
|
if not has_max_new_tokens:
|
||||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
|
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.generation_config.max_length
|
||||||
generate_kwargs["max_length"] += prefix_length
|
generate_kwargs["max_length"] += prefix_length
|
||||||
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
||||||
"generation_config" in generate_kwargs
|
"generation_config" in generate_kwargs
|
||||||
@@ -363,7 +363,10 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
||||||
generate_kwargs["min_length"] += prefix_length
|
generate_kwargs["min_length"] += prefix_length
|
||||||
|
|
||||||
# BS x SL
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||||
out_b = generated_sequence.shape[0]
|
out_b = generated_sequence.shape[0]
|
||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
if self.model.config.model_type == "bark":
|
if self.model.config.model_type == "bark":
|
||||||
# bark Tokenizer is called with BarkProcessor which uses those kwargs
|
# bark Tokenizer is called with BarkProcessor which uses those kwargs
|
||||||
new_kwargs = {
|
new_kwargs = {
|
||||||
"max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
|
"max_length": self.generation_config.semantic_config.get("max_input_semantic_length", 256),
|
||||||
"add_special_tokens": False,
|
"add_special_tokens": False,
|
||||||
"return_attention_mask": True,
|
"return_attention_mask": True,
|
||||||
"return_token_type_ids": False,
|
"return_token_type_ids": False,
|
||||||
@@ -137,6 +137,10 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||||
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
|
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
|
||||||
|
|
||||||
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
# generate_kwargs get priority over forward_params
|
# generate_kwargs get priority over forward_params
|
||||||
forward_params.update(generate_kwargs)
|
forward_params.update(generate_kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -162,6 +162,10 @@ class VisualQuestionAnsweringPipeline(Pipeline):
|
|||||||
|
|
||||||
def _forward(self, model_inputs, **generate_kwargs):
|
def _forward(self, model_inputs, **generate_kwargs):
|
||||||
if self.model.can_generate():
|
if self.model.can_generate():
|
||||||
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
||||||
|
if "generation_config" not in generate_kwargs:
|
||||||
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
||||||
else:
|
else:
|
||||||
model_outputs = self.model(**model_inputs)
|
model_outputs = self.model(**model_inputs)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DistilBertForSequenceClassification,
|
DistilBertForSequenceClassification,
|
||||||
MaskGenerationPipeline,
|
MaskGenerationPipeline,
|
||||||
|
T5ForConditionalGeneration,
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
TextGenerationPipeline,
|
TextGenerationPipeline,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
@@ -234,6 +235,31 @@ class CommonPipelineTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load
|
self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_pipeline_with_task_parameters_no_side_effects(self):
|
||||||
|
"""
|
||||||
|
Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected
|
||||||
|
side-effects
|
||||||
|
"""
|
||||||
|
# This checkpoint has task-specific parameters that will modify the behavior of the pipeline
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||||
|
self.assertTrue(model.config.num_beams == 1)
|
||||||
|
|
||||||
|
# The task-specific parameters used to cause side-effects on `model.config` -- not anymore
|
||||||
|
pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de")
|
||||||
|
self.assertTrue(model.config.num_beams == 1)
|
||||||
|
self.assertTrue(model.generation_config.num_beams == 1)
|
||||||
|
|
||||||
|
# Under the hood: we now store a generation config in the pipeline. This generation config stores the
|
||||||
|
# task-specific paremeters.
|
||||||
|
self.assertTrue(pipe.generation_config.num_beams == 4)
|
||||||
|
|
||||||
|
# We can confirm that the task-specific parameters have an effect. (In this case, the default is `num_beams=1`,
|
||||||
|
# which would crash when `num_return_sequences=4` is passed.)
|
||||||
|
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4, num_beams=1)
|
||||||
|
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
class PipelineScikitCompatTest(unittest.TestCase):
|
class PipelineScikitCompatTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -1715,6 +1715,38 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_save_and_load_config_with_custom_generation(self):
|
||||||
|
"""
|
||||||
|
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
|
||||||
|
that gets moved to the generation config and reset on the model config)
|
||||||
|
"""
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||||
|
|
||||||
|
# The default for `num_beams` is 1 and `early_stopping` is False
|
||||||
|
self.assertTrue(model.config.num_beams == 1)
|
||||||
|
self.assertTrue(model.config.early_stopping is False)
|
||||||
|
|
||||||
|
# When we save the model, this custom parameter should be moved to the generation config AND the model
|
||||||
|
# config should contain `None`
|
||||||
|
model.config.num_beams = 2
|
||||||
|
model.config.early_stopping = True
|
||||||
|
self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
|
||||||
|
# moved to generation config
|
||||||
|
self.assertTrue(new_model.generation_config.num_beams == 2)
|
||||||
|
self.assertTrue(new_model.generation_config.early_stopping is True)
|
||||||
|
# reset in the model config
|
||||||
|
self.assertTrue(new_model.config.num_beams is None)
|
||||||
|
self.assertTrue(new_model.config.early_stopping is None)
|
||||||
|
|
||||||
|
# Sanity check: We can run `generate` with the new model without any warnings
|
||||||
|
random_ids = torch.randint(0, 100, (1, 5))
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
new_model.generate(random_ids, max_new_tokens=3)
|
||||||
|
self.assertTrue(len(w) == 0)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user