fix image2test args forwarding (#19648)
* fix image2test args forwarding * fix issues * Proposing the update to the PR. * Fixup. Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -44,8 +44,20 @@ class ImageToTextPipeline(Pipeline):
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
return {}, {}, {}
|
||||
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
|
||||
forward_kwargs = {}
|
||||
if generate_kwargs is not None:
|
||||
forward_kwargs["generate_kwargs"] = generate_kwargs
|
||||
if max_new_tokens is not None:
|
||||
if "generate_kwargs" not in forward_kwargs:
|
||||
forward_kwargs["generate_kwargs"] = {}
|
||||
if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
|
||||
raise ValueError(
|
||||
"'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
|
||||
" please use only one"
|
||||
)
|
||||
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
|
||||
return {}, forward_kwargs, {}
|
||||
|
||||
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
|
||||
"""
|
||||
@@ -61,6 +73,12 @@ class ImageToTextPipeline(Pipeline):
|
||||
|
||||
The pipeline accepts either a single image or a batch of images.
|
||||
|
||||
max_new_tokens (`int`, *optional*):
|
||||
The amount of maximum tokens to generate. By default it will use `generate` default.
|
||||
|
||||
generate_kwargs (`Dict`, *optional*):
|
||||
Pass it to send all of these arguments directly to `generate` allowing full control of this function.
|
||||
|
||||
Return:
|
||||
A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
|
||||
|
||||
@@ -73,13 +91,15 @@ class ImageToTextPipeline(Pipeline):
|
||||
model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
def _forward(self, model_inputs, generate_kwargs=None):
|
||||
if generate_kwargs is None:
|
||||
generate_kwargs = {}
|
||||
# FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py`
|
||||
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
|
||||
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
|
||||
# in the `_prepare_model_inputs` method.
|
||||
inputs = model_inputs.pop(self.model.main_input_name)
|
||||
model_outputs = self.model.generate(inputs, **model_inputs)
|
||||
model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
|
||||
@@ -86,6 +86,12 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
|
||||
],
|
||||
)
|
||||
|
||||
outputs = pipe(image, max_new_tokens=1)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[{"generated_text": "growth"}],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2")
|
||||
|
||||
Reference in New Issue
Block a user