Shorten the conversation tests for speed + fixing position overflows (#26960)

* Shorten the conversation tests for speed + fixing position overflows

* Put max_new_tokens back to 5

* Remove test skips

* Increase max_position_embeddings in blenderbot tests

* Add skips for blenderbot_small

* Correct TF test skip

* make fixup

* Reformat skips to use is_pipeline_test_to_skip

* Update tests/models/blenderbot_small/test_modeling_blenderbot_small.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blenderbot_small/test_modeling_flax_blenderbot_small.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Matt
2023-10-31 14:20:04 +00:00
committed by GitHub
parent a8e74ebdc5
commit 08fadc8085
8 changed files with 22 additions and 26 deletions

View File

@@ -77,14 +77,14 @@ class ConversationalPipelineTests(unittest.TestCase):
def run_pipeline_test(self, conversation_agent, _):
# Simple
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20)
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=5)
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
)
# Single list
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20)
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=5)
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
@@ -96,7 +96,7 @@ class ConversationalPipelineTests(unittest.TestCase):
self.assertEqual(len(conversation_1), 1)
self.assertEqual(len(conversation_2), 1)
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20)
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=5)
self.assertEqual(outputs, [conversation_1, conversation_2])
self.assertEqual(
outputs,
@@ -118,7 +118,7 @@ class ConversationalPipelineTests(unittest.TestCase):
# One conversation with history
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
outputs = conversation_agent(conversation_2, max_new_tokens=20)
outputs = conversation_agent(conversation_2, max_new_tokens=5)
self.assertEqual(outputs, conversation_2)
self.assertEqual(
outputs,