From d3f4cef74d94d52e967e76a2918e8368ee2e3fe0 Mon Sep 17 00:00:00 2001 From: Rak Alexey Date: Mon, 24 Oct 2022 16:49:24 +0300 Subject: [PATCH] fix image2test args forwarding (#19648) * fix image2test args forwarding * fix issues * Proposing the update to the PR. * Fixup. Co-authored-by: Nicolas Patry --- src/transformers/pipelines/image_to_text.py | 28 ++++++++++++++++--- .../pipelines/test_pipelines_image_to_text.py | 6 ++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py index 934525533e..b4547518e6 100644 --- a/src/transformers/pipelines/image_to_text.py +++ b/src/transformers/pipelines/image_to_text.py @@ -44,8 +44,20 @@ class ImageToTextPipeline(Pipeline): TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING ) - def _sanitize_parameters(self, **kwargs): - return {}, {}, {} + def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None): + forward_kwargs = {} + if generate_kwargs is not None: + forward_kwargs["generate_kwargs"] = generate_kwargs + if max_new_tokens is not None: + if "generate_kwargs" not in forward_kwargs: + forward_kwargs["generate_kwargs"] = {} + if "max_new_tokens" in forward_kwargs["generate_kwargs"]: + raise ValueError( + "'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter," + " please use only one" + ) + forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens + return {}, forward_kwargs, {} def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs): """ @@ -61,6 +73,12 @@ class ImageToTextPipeline(Pipeline): The pipeline accepts either a single image or a batch of images. + max_new_tokens (`int`, *optional*): + The amount of maximum tokens to generate. By default it will use `generate` default. + + generate_kwargs (`Dict`, *optional*): + Pass it to send all of these arguments directly to `generate` allowing full control of this function. + Return: A list or a list of list of `dict`: Each result comes as a dictionary with the following key: @@ -73,13 +91,15 @@ class ImageToTextPipeline(Pipeline): model_inputs = self.feature_extractor(images=image, return_tensors=self.framework) return model_inputs - def _forward(self, model_inputs): + def _forward(self, model_inputs, generate_kwargs=None): + if generate_kwargs is None: + generate_kwargs = {} # FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py` # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name` # in the `_prepare_model_inputs` method. inputs = model_inputs.pop(self.model.main_input_name) - model_outputs = self.model.generate(inputs, **model_inputs) + model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs) return model_outputs def postprocess(self, model_outputs): diff --git a/tests/pipelines/test_pipelines_image_to_text.py b/tests/pipelines/test_pipelines_image_to_text.py index 652c140ae5..0e1e805f9b 100644 --- a/tests/pipelines/test_pipelines_image_to_text.py +++ b/tests/pipelines/test_pipelines_image_to_text.py @@ -86,6 +86,12 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta ], ) + outputs = pipe(image, max_new_tokens=1) + self.assertEqual( + outputs, + [{"generated_text": "growth"}], + ) + @require_torch def test_small_model_pt(self): pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2")