Fixing return type tensor with num_return_sequences>1. (#16828)

* Fixing return type tensor with `num_return_sequences>1`.

* Nit.
This commit is contained in:
Nicolas Patry
2022-04-20 16:11:51 +02:00
committed by GitHub
parent ff06b17791
commit e13a91fe60
4 changed files with 69 additions and 2 deletions

View File

@@ -56,6 +56,37 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
],
)
outputs = text_generator("This is a test", do_sample=True, num_return_sequences=2, return_tensors=True)
self.assertEqual(
outputs,
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
)
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
text_generator.tokenizer.pad_token = "<pad>"
outputs = text_generator(
["This is a test", "This is a second test"],
do_sample=True,
num_return_sequences=2,
batch_size=2,
return_tensors=True,
)
self.assertEqual(
outputs,
[
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
],
)
@require_tf
def test_small_model_tf(self):
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")