From 3a08dc63fd788f768e1f16a97db14d0015368940 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 3 May 2023 14:43:17 +0100 Subject: [PATCH] Generate: better warnings with pipelines (#23128) --- src/transformers/pipelines/base.py | 4 +++- src/transformers/pipelines/text2text_generation.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index b728e94f34..de6c9a8ec4 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -803,10 +803,12 @@ class Pipeline(_ScikitCompat): self.torch_dtype = torch_dtype self.binary_output = binary_output - # Update config with task specific parameters + # Update config and generation_config with task specific parameters task_specific_params = self.model.config.task_specific_params if task_specific_params is not None and task in task_specific_params: self.model.config.update(task_specific_params.get(task)) + if self.model.can_generate(): + self.model.generation_config.update(**task_specific_params.get(task)) self.call_count = 0 self._batch_size = kwargs.pop("batch_size", None) diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index dbd45e6ff1..48df10336a 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -273,7 +273,8 @@ class SummarizationPipeline(Text2TextGenerationPipeline): if input_length < max_length: logger.warning( - f"Your max_length is set to {max_length}, but your input_length is only {input_length}. You might " + f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is " + "a summarization task, where outputs shorter than the input are typically wanted, you might " f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length//2})" )