Fixing support for `batch_size` and `num_return_sequences` in text-generation pipeline (#15318)
* Fixing support for `batch_size` and `num_return_sequences` in the `text-generation` pipeline, and in `text2text-generation` too. The bug was caused by the batch_size dimension containing both the incoming batch **and** the generated `num_return_sequences`. The fix simply consists of splitting these two back into separate dimensions. * TF support. * Odd backward compatibility script in the way.
This commit is contained in:
@@ -40,6 +40,26 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
# These are encoder decoder, they don't just append to incoming string
|
||||
self.assertFalse(outputs[0]["generated_text"].startswith("Something there"))
|
||||
|
||||
outputs = generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
],
|
||||
)
|
||||
|
||||
outputs = generator(
|
||||
["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
],
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
generator(4)
|
||||
|
||||
|
||||
@@ -113,6 +113,27 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
|
||||
self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))
|
||||
|
||||
outputs = text_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
],
|
||||
)
|
||||
|
||||
if text_generator.tokenizer.pad_token is not None:
|
||||
outputs = text_generator(
|
||||
["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||
],
|
||||
)
|
||||
|
||||
# Empty prompt is slightly special
|
||||
# it requires BOS token to exist.
|
||||
# Special case for Pegasus which will always append EOS so will
|
||||
|
||||
Reference in New Issue
Block a user