From 06dd597552a78c53ce52ec1acf5ff3ab3ece82ea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Apr 2020 21:59:12 +0200 Subject: [PATCH] fix bug in warnings T5 pipelines (#3545) --- src/transformers/pipelines.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 89caf192cd..d6c017749a 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1235,17 +1235,19 @@ class SummarizationPipeline(Pipeline): elif self.framework == "tf": input_length = tf.shape(inputs["input_ids"])[-1].numpy() - if input_length < self.model.config.min_length // 2: + min_length = generate_kwargs.get("min_length", self.model.config.min_length) + if input_length < min_length // 2: logger.warning( "Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format( - self.model.config.min_length, input_length + min_length, input_length ) ) - if input_length < self.model.config.max_length: + max_length = generate_kwargs.get("max_length", self.model.config.max_length) + if input_length < max_length: logger.warning( "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format( - self.model.config.max_length, input_length + max_length, input_length ) ) @@ -1349,10 +1351,11 @@ class TranslationPipeline(Pipeline): elif self.framework == "tf": input_length = tf.shape(inputs["input_ids"])[-1].numpy() - if input_length > 0.9 * self.model.config.max_length: + max_length = generate_kwargs.get("max_length", self.model.config.max_length) + if input_length > 0.9 * max_length: logger.warning( "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format( - input_length, self.model.config.max_length + input_length, max_length ) )