Extend pipelines for automodel tupels (#12025)
* fix_torch_device_generate_test * remove @ * finish * refactor * add test * fix test * Attempt at simplification. * Small fix. * Fixing non existing AutoModel for TF. * Naming. * Remove extra condition. Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -18,6 +18,8 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BlenderbotSmallForConditionalGeneration,
|
||||
BlenderbotSmallTokenizer,
|
||||
Conversation,
|
||||
ConversationalPipeline,
|
||||
is_torch_available,
|
||||
@@ -389,3 +391,32 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
|
||||
self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
|
||||
self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_from_pipeline_conversation(self):
|
||||
model_id = "facebook/blenderbot_small-90M"
|
||||
|
||||
# from model id
|
||||
conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)
|
||||
|
||||
# from model object
|
||||
model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
|
||||
tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
|
||||
conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)
|
||||
|
||||
conversation = Conversation("My name is Sarah and I live in London")
|
||||
conversation_copy = Conversation("My name is Sarah and I live in London")
|
||||
|
||||
result_model_id = conversation_agent_from_model_id([conversation])
|
||||
result_model = conversation_agent_from_model([conversation_copy])
|
||||
|
||||
# check for equality
|
||||
self.assertEqual(
|
||||
result_model_id.generated_responses[0],
|
||||
"hi sarah, i live in london as well. do you have any plans for the weekend?",
|
||||
)
|
||||
self.assertEqual(
|
||||
result_model_id.generated_responses[0],
|
||||
result_model.generated_responses[0],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user