Generate: better warnings with pipelines (#23128)

This commit is contained in:
Joao Gante
2023-05-03 14:43:17 +01:00
committed by GitHub
parent 2a16d8b275
commit 3a08dc63fd
2 changed files with 5 additions and 2 deletions

View File

@@ -803,10 +803,12 @@ class Pipeline(_ScikitCompat):
self.torch_dtype = torch_dtype
self.binary_output = binary_output
# Update config with task specific parameters
# Update config and generation_config with task specific parameters
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
if self.model.can_generate():
self.model.generation_config.update(**task_specific_params.get(task))
self.call_count = 0
self._batch_size = kwargs.pop("batch_size", None)

View File

@@ -273,7 +273,8 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
if input_length < max_length:
logger.warning(
f"Your max_length is set to {max_length}, but your input_length is only {input_length}. You might "
f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is "
"a summarization task, where outputs shorter than the input are typically wanted, you might "
f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length//2})"
)