Adding handle_long_generation parameters for text-generation pipeline. (#14118)

* Adding `handle_long_generation` parameters for `text-generation` pipeline.

* More error handling

* Fixing tests by dropping TF support for this functionality: it needs
`max_new_tokens` to make it possible to understand the user's intent.
Otherwise, `max_length` == `tokenizer.model_max_length` <
input_ids.shape[0].

* Fixing doc ?

* Doc ?

* Remove link from doc.

* Caught an issue on RoBERTa.

* Damn doc.

* Non BC proposal ?

* Cleaning the fix ?

* Finally using only a test override.

* Don't need to modify this.

* Bad print.
This commit is contained in:
Nicolas Patry
2021-10-29 15:29:28 +02:00
committed by GitHub
parent d37f1fb8ba
commit dc540dd316
4 changed files with 68 additions and 4 deletions

View File

@@ -143,7 +143,9 @@ class PipelineTestCaseMeta(type):
try:
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
# XLNet actually defines it as -1.
if (
if model.config.__class__.__name__ == "RobertaConfig":
tokenizer.model_max_length = model.config.max_position_embeddings - 2
elif (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings > 0
):

View File

@@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
else:
with self.assertRaises((ValueError, AssertionError)):
outputs = text_generator("")
if text_generator.framework == "tf":
# TF generation does not support max_new_tokens, and it's impossible
# to control long generation with only max_length without
# fancy calculation, dismissing tests for now.
return
# We don't care about infinite range models.
# They already work.
if tokenizer.model_max_length < 10000:
# Handling of large generations
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
text_generator("This is a test" * 500, max_new_tokens=20)
outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
# Hole strategy cannot work
with self.assertRaises(ValueError):
text_generator(
"This is a test" * 500,
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)