Output dicts support in text generation pipeline (#35092)
* Support for generate_argument: return_dict_in_generate=True, instead of returning a error * fix: call test with return_dict_in_generate=True * fix: Only import torch if it is present * update: Encapsulate output_dict changes * fix: added back original comments --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -653,6 +653,31 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
_ = text_generator(prompt, max_length=10)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
||||
def test_return_dict_in_generate(self):
|
||||
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=16)
|
||||
out = text_generator(
|
||||
["This is great !", "Something else"], return_dict_in_generate=True, output_logits=True, output_scores=True
|
||||
)
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
[
|
||||
{
|
||||
"generated_text": ANY(str),
|
||||
"logits": ANY(list),
|
||||
"scores": ANY(list),
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"generated_text": ANY(str),
|
||||
"logits": ANY(list),
|
||||
"scores": ANY(list),
|
||||
},
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_assisted_generation(self):
|
||||
"""Tests that we can run assisted generation in the pipeline"""
|
||||
|
||||
Reference in New Issue
Block a user