From 366c03271e01c86e9a573bb64481f185de11ef29 Mon Sep 17 00:00:00 2001 From: thedamnedrhino Date: Mon, 15 Jan 2024 07:52:18 -0800 Subject: [PATCH] Tokenizer kwargs in textgeneration pipe (#28362) * added args to the pipeline * added test * more sensical tests * fixup * docs * typo ; * docs * made changes to support named args * fixed test * docs update * styles * docs * docs --- docs/source/en/preprocessing.md | 6 ++++ src/transformers/pipelines/text_generation.py | 30 +++++++++++++++++-- .../test_pipelines_text_generation.py | 16 ++++++++++ 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/docs/source/en/preprocessing.md b/docs/source/en/preprocessing.md index 743904cc99..4aa9030fe4 100644 --- a/docs/source/en/preprocessing.md +++ b/docs/source/en/preprocessing.md @@ -216,6 +216,12 @@ array([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + +Different pipelines support tokenizer arguments in their `__call__()` differently. `text2text-generation` pipelines support (i.e. pass on) +only `truncation`. `text-generation` pipelines support `max_length`, `truncation`, `padding` and `add_special_tokens`. +In `fill-mask` pipelines, tokenizer arguments can be passed in the `tokenizer_kwargs` argument (dictionary). + + ## Audio For audio tasks, you'll need a [feature extractor](main_classes/feature_extractor) to prepare your dataset for the model. The feature extractor is designed to extract features from raw audio data, and convert them into tensors. 
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 109971d8ac..fe0a49a476 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -104,9 +104,20 @@ class TextGenerationPipeline(Pipeline): handle_long_generation=None, stop_sequence=None, add_special_tokens=False, + truncation=None, + padding=False, + max_length=None, **generate_kwargs, ): - preprocess_params = {"add_special_tokens": add_special_tokens} + preprocess_params = { + "add_special_tokens": add_special_tokens, + "truncation": truncation, + "padding": padding, + "max_length": max_length, + } + if max_length is not None: + generate_kwargs["max_length"] = max_length + if prefix is not None: preprocess_params["prefix"] = prefix if prefix: @@ -208,10 +219,23 @@ class TextGenerationPipeline(Pipeline): return super().__call__(text_inputs, **kwargs) def preprocess( - self, prompt_text, prefix="", handle_long_generation=None, add_special_tokens=False, **generate_kwargs + self, + prompt_text, + prefix="", + handle_long_generation=None, + add_special_tokens=False, + truncation=None, + padding=False, + max_length=None, + **generate_kwargs, ): inputs = self.tokenizer( - prefix + prompt_text, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework + prefix + prompt_text, + return_tensors=self.framework, + truncation=truncation, + padding=padding, + max_length=max_length, + add_special_tokens=add_special_tokens, ) inputs["prompt_text"] = prompt_text diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index b80944e80e..bf4c1e9f9d 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -90,6 +90,22 @@ class TextGenerationPipelineTests(unittest.TestCase): {"generated_token_ids": ANY(list)}, ], ) + + ## -- test tokenizer_kwargs + test_str = "testing 
tokenizer kwargs. using truncation must result in a different generation." + output_str, output_str_with_truncation = ( + text_generator(test_str, do_sample=False, return_full_text=False)[0]["generated_text"], + text_generator( + test_str, + do_sample=False, + return_full_text=False, + truncation=True, + max_length=3, + )[0]["generated_text"], + ) + assert output_str != output_str_with_truncation # results must be different because one had truncation + + # -- what is the point of this test? padding is hardcoded False in the pipeline anyway text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id text_generator.tokenizer.pad_token = "" outputs = text_generator(