Tokenizer kwargs in textgeneration pipe (#28362)
* added args to the pipeline * added test * more sensical tests * fixup * docs * typo ; * docs * made changes to support named args * fixed test * docs update * styles * docs * docs
This commit is contained in:
@@ -216,6 +216,12 @@ array([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
|
|||||||
</tf>
|
</tf>
|
||||||
</frameworkcontent>
|
</frameworkcontent>
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
Different pipelines support tokenizer arguments in their `__call__()` differently. `text2text-generation` pipelines support (i.e. pass on)
|
||||||
|
only `truncation`. `text-generation` pipelines support `max_length`, `truncation`, `padding` and `add_special_tokens`.
|
||||||
|
In `fill-mask` pipelines, tokenizer arguments can be passed in the `tokenizer_kwargs` argument (dictionary).
|
||||||
|
</Tip>
|
||||||
|
|
||||||
## Audio
|
## Audio
|
||||||
|
|
||||||
For audio tasks, you'll need a [feature extractor](main_classes/feature_extractor) to prepare your dataset for the model. The feature extractor is designed to extract features from raw audio data, and convert them into tensors.
|
For audio tasks, you'll need a [feature extractor](main_classes/feature_extractor) to prepare your dataset for the model. The feature extractor is designed to extract features from raw audio data, and convert them into tensors.
|
||||||
|
|||||||
@@ -104,9 +104,20 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
handle_long_generation=None,
|
handle_long_generation=None,
|
||||||
stop_sequence=None,
|
stop_sequence=None,
|
||||||
add_special_tokens=False,
|
add_special_tokens=False,
|
||||||
|
truncation=None,
|
||||||
|
padding=False,
|
||||||
|
max_length=None,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
):
|
):
|
||||||
preprocess_params = {"add_special_tokens": add_special_tokens}
|
preprocess_params = {
|
||||||
|
"add_special_tokens": add_special_tokens,
|
||||||
|
"truncation": truncation,
|
||||||
|
"padding": padding,
|
||||||
|
"max_length": max_length,
|
||||||
|
}
|
||||||
|
if max_length is not None:
|
||||||
|
generate_kwargs["max_length"] = max_length
|
||||||
|
|
||||||
if prefix is not None:
|
if prefix is not None:
|
||||||
preprocess_params["prefix"] = prefix
|
preprocess_params["prefix"] = prefix
|
||||||
if prefix:
|
if prefix:
|
||||||
@@ -208,10 +219,23 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
return super().__call__(text_inputs, **kwargs)
|
return super().__call__(text_inputs, **kwargs)
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self, prompt_text, prefix="", handle_long_generation=None, add_special_tokens=False, **generate_kwargs
|
self,
|
||||||
|
prompt_text,
|
||||||
|
prefix="",
|
||||||
|
handle_long_generation=None,
|
||||||
|
add_special_tokens=False,
|
||||||
|
truncation=None,
|
||||||
|
padding=False,
|
||||||
|
max_length=None,
|
||||||
|
**generate_kwargs,
|
||||||
):
|
):
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
prefix + prompt_text, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework
|
prefix + prompt_text,
|
||||||
|
return_tensors=self.framework,
|
||||||
|
truncation=truncation,
|
||||||
|
padding=padding,
|
||||||
|
max_length=max_length,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
)
|
)
|
||||||
inputs["prompt_text"] = prompt_text
|
inputs["prompt_text"] = prompt_text
|
||||||
|
|
||||||
|
|||||||
@@ -90,6 +90,22 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
{"generated_token_ids": ANY(list)},
|
{"generated_token_ids": ANY(list)},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## -- test tokenizer_kwargs
|
||||||
|
test_str = "testing tokenizer kwargs. using truncation must result in a different generation."
|
||||||
|
output_str, output_str_with_truncation = (
|
||||||
|
text_generator(test_str, do_sample=False, return_full_text=False)[0]["generated_text"],
|
||||||
|
text_generator(
|
||||||
|
test_str,
|
||||||
|
do_sample=False,
|
||||||
|
return_full_text=False,
|
||||||
|
truncation=True,
|
||||||
|
max_length=3,
|
||||||
|
)[0]["generated_text"],
|
||||||
|
)
|
||||||
|
assert output_str != output_str_with_truncation # results must be different because one had truncation
|
||||||
|
|
||||||
|
# -- what is the point of this test? padding defaults to False in the pipeline anyway
|
||||||
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
|
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
|
||||||
text_generator.tokenizer.pad_token = "<pad>"
|
text_generator.tokenizer.pad_token = "<pad>"
|
||||||
outputs = text_generator(
|
outputs = text_generator(
|
||||||
|
|||||||
Reference in New Issue
Block a user