Shorten the conversation tests for speed + fixing position overflows (#26960)
* Shorten the conversation tests for speed + fixing position overflows * Put max_new_tokens back to 5 * Remove test skips * Increase max_position_embeddings in blenderbot tests * Add skips for blenderbot_small * Correct TF test skip * make fixup * Reformat skips to use is_pipeline_test_to_skip * Update tests/models/blenderbot_small/test_modeling_blenderbot_small.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/blenderbot_small/test_modeling_flax_blenderbot_small.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -85,7 +85,7 @@ class BlenderbotModelTester:
|
|||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=20,
|
max_position_embeddings=50,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class FlaxBlenderbotModelTester:
|
|||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=32,
|
max_position_embeddings=50,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class TFBlenderbotModelTester:
|
|||||||
intermediate_size=37,
|
intermediate_size=37,
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=20,
|
max_position_embeddings=50,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class BlenderbotSmallModelTester:
|
|||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=20,
|
max_position_embeddings=50,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
@@ -242,12 +242,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
|||||||
def is_pipeline_test_to_skip(
|
def is_pipeline_test_to_skip(
|
||||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
):
|
):
|
||||||
if pipeline_test_casse_name == "TextGenerationPipelineTests":
|
return pipeline_test_casse_name in ("TextGenerationPipelineTests", "ConversationalPipelineTests")
|
||||||
return True
|
|
||||||
# TODO @Rocketnight1 to fix
|
|
||||||
if pipeline_test_casse_name == "ConversationalPipelineTests":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = BlenderbotSmallModelTester(self)
|
self.model_tester = BlenderbotSmallModelTester(self)
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class FlaxBlenderbotSmallModelTester:
|
|||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=32,
|
max_position_embeddings=50,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
@@ -320,6 +320,11 @@ class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, Flax
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (FlaxBlenderbotSmallForConditionalGeneration,) if is_flax_available() else ()
|
all_generative_model_classes = (FlaxBlenderbotSmallForConditionalGeneration,) if is_flax_available() else ()
|
||||||
|
|
||||||
|
def is_pipeline_test_to_skip(
|
||||||
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
|
):
|
||||||
|
return pipeline_test_casse_name in ("TextGenerationPipelineTests", "ConversationalPipelineTests")
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = FlaxBlenderbotSmallModelTester(self)
|
self.model_tester = FlaxBlenderbotSmallModelTester(self)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class TFBlenderbotSmallModelTester:
|
|||||||
intermediate_size=37,
|
intermediate_size=37,
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=20,
|
max_position_embeddings=50,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
@@ -198,6 +198,11 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_onnx = False
|
test_onnx = False
|
||||||
|
|
||||||
|
def is_pipeline_test_to_skip(
|
||||||
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
|
):
|
||||||
|
return pipeline_test_casse_name in ("TextGenerationPipelineTests", "ConversationalPipelineTests")
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=BlenderbotSmallConfig)
|
self.config_tester = ConfigTester(self, config_class=BlenderbotSmallConfig)
|
||||||
@@ -209,15 +214,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
|
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
# TODO: Fix the failed tests when this model gets more usage
|
|
||||||
def is_pipeline_test_to_skip(
|
|
||||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
|
||||||
):
|
|
||||||
# TODO @Rocketnight1 to fix
|
|
||||||
if pipeline_test_casse_name == "ConversationalPipelineTests":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
@require_tf
|
@require_tf
|
||||||
|
|||||||
@@ -77,14 +77,14 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
def run_pipeline_test(self, conversation_agent, _):
|
def run_pipeline_test(self, conversation_agent, _):
|
||||||
# Simple
|
# Simple
|
||||||
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20)
|
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=5)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Single list
|
# Single list
|
||||||
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20)
|
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=5)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||||
@@ -96,7 +96,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
self.assertEqual(len(conversation_1), 1)
|
self.assertEqual(len(conversation_1), 1)
|
||||||
self.assertEqual(len(conversation_2), 1)
|
self.assertEqual(len(conversation_2), 1)
|
||||||
|
|
||||||
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20)
|
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=5)
|
||||||
self.assertEqual(outputs, [conversation_1, conversation_2])
|
self.assertEqual(outputs, [conversation_1, conversation_2])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
@@ -118,7 +118,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
# One conversation with history
|
# One conversation with history
|
||||||
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
|
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
|
||||||
outputs = conversation_agent(conversation_2, max_new_tokens=20)
|
outputs = conversation_agent(conversation_2, max_new_tokens=5)
|
||||||
self.assertEqual(outputs, conversation_2)
|
self.assertEqual(outputs, conversation_2)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ class PipelineTesterMixin:
|
|||||||
|
|
||||||
out = []
|
out = []
|
||||||
if task == "conversational":
|
if task == "conversational":
|
||||||
for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
|
for item in pipeline(data(10), batch_size=4, max_new_tokens=5):
|
||||||
out.append(item)
|
out.append(item)
|
||||||
else:
|
else:
|
||||||
for item in pipeline(data(10), batch_size=4):
|
for item in pipeline(data(10), batch_size=4):
|
||||||
|
|||||||
Reference in New Issue
Block a user