Adding handle_long_generation parameters for text-generation pipeline. (#14118)
* Adding `handle_long_generation` parameters for `text-generation` pipeline. * More error handling * Fixing tests by dropping TF support on this functionality, it needs `max_new_tokens` to make it possible to understand the user's intent. Otherwise, `max_length` == `tokenizer.model_max_length` < input_ids.shape[0]. * Fixing doc ? * Doc ? * Remove link from doc. * Caught an issue on roberta. * Damn doc. * Non BC proposal ? * Cleaning the fix ? * Finally using only a test override. * Don't need to modify this. * Bad print.
This commit is contained in:
@@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
else:
|
||||
with self.assertRaises((ValueError, AssertionError)):
|
||||
outputs = text_generator("")
|
||||
|
||||
if text_generator.framework == "tf":
|
||||
# TF generation does not support max_new_tokens, and it's impossible
|
||||
# to control long generation with only max_length without
|
||||
# fancy calculation, dismissing tests for now.
|
||||
return
|
||||
# We don't care about infinite range models.
|
||||
# They already work.
|
||||
if tokenizer.model_max_length < 10000:
|
||||
# Handling of large generations
|
||||
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
|
||||
text_generator("This is a test" * 500, max_new_tokens=20)
|
||||
|
||||
outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
|
||||
# Hole strategy cannot work
|
||||
with self.assertRaises(ValueError):
|
||||
text_generator(
|
||||
"This is a test" * 500,
|
||||
handle_long_generation="hole",
|
||||
max_new_tokens=tokenizer.model_max_length + 10,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user