Generate: better warnings with pipelines (#23128)
This commit is contained in:
@@ -803,10 +803,12 @@ class Pipeline(_ScikitCompat):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.binary_output = binary_output
|
||||
|
||||
# Update config with task specific parameters
|
||||
# Update config and generation_config with task specific parameters
|
||||
task_specific_params = self.model.config.task_specific_params
|
||||
if task_specific_params is not None and task in task_specific_params:
|
||||
self.model.config.update(task_specific_params.get(task))
|
||||
if self.model.can_generate():
|
||||
self.model.generation_config.update(**task_specific_params.get(task))
|
||||
|
||||
self.call_count = 0
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
|
||||
@@ -273,7 +273,8 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
|
||||
|
||||
if input_length < max_length:
|
||||
logger.warning(
|
||||
f"Your max_length is set to {max_length}, but your input_length is only {input_length}. You might "
|
||||
f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is "
|
||||
"a summarization task, where outputs shorter than the input are typically wanted, you might "
|
||||
f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length//2})"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user