Fixed: Better names for nlp variables in pipelines' tests and docs. (#11752)
* Fixed: Better names for nlp variables in pipelines' tests and docs. * Fixed: Better variable names
This commit is contained in:
@@ -128,41 +128,41 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
invalid_inputs = ["Hi there!", Conversation()]
|
||||
|
||||
def _test_pipeline(
|
||||
self, nlp
|
||||
self, conversation_agent
|
||||
): # override the default test method to check that the output is a `Conversation` object
|
||||
self.assertIsNotNone(nlp)
|
||||
self.assertIsNotNone(conversation_agent)
|
||||
|
||||
# We need to recreate conversation for successive tests to pass as
|
||||
# Conversation objects get *consumed* by the pipeline
|
||||
conversation = Conversation("Hi there!")
|
||||
mono_result = nlp(conversation)
|
||||
mono_result = conversation_agent(conversation)
|
||||
self.assertIsInstance(mono_result, Conversation)
|
||||
|
||||
conversations = [Conversation("Hi there!"), Conversation("How are you?")]
|
||||
multi_result = nlp(conversations)
|
||||
multi_result = conversation_agent(conversations)
|
||||
self.assertIsInstance(multi_result, list)
|
||||
self.assertIsInstance(multi_result[0], Conversation)
|
||||
# Conversation have been consumed and are not valid anymore
|
||||
# Inactive conversations passed to the pipeline raise a ValueError
|
||||
self.assertRaises(ValueError, nlp, conversation)
|
||||
self.assertRaises(ValueError, nlp, conversations)
|
||||
self.assertRaises(ValueError, conversation_agent, conversation)
|
||||
self.assertRaises(ValueError, conversation_agent, conversations)
|
||||
|
||||
for bad_input in self.invalid_inputs:
|
||||
self.assertRaises(Exception, nlp, bad_input)
|
||||
self.assertRaises(Exception, nlp, self.invalid_inputs)
|
||||
self.assertRaises(Exception, conversation_agent, bad_input)
|
||||
self.assertRaises(Exception, conversation_agent, self.invalid_inputs)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation(self):
|
||||
# When
|
||||
nlp = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
|
||||
conversation_agent = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation_2 = Conversation("What's the last book you have read?")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
# When
|
||||
result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
self.assertEqual(result, [conversation_1, conversation_2])
|
||||
self.assertEqual(len(result[0].past_user_inputs), 1)
|
||||
@@ -175,7 +175,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(result[1].generated_responses[0], "The Last Question")
|
||||
# When
|
||||
conversation_2.add_user_input("Why do you recommend it?")
|
||||
result = nlp(conversation_2, do_sample=False, max_length=1000)
|
||||
result = conversation_agent(conversation_2, do_sample=False, max_length=1000)
|
||||
# Then
|
||||
self.assertEqual(result, conversation_2)
|
||||
self.assertEqual(len(result.past_user_inputs), 2)
|
||||
@@ -187,12 +187,12 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
@slow
|
||||
def test_integration_torch_conversation_truncated_history(self):
|
||||
# When
|
||||
nlp = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
|
||||
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
# When
|
||||
result = nlp(conversation_1, do_sample=False, max_length=36)
|
||||
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
||||
# Then
|
||||
self.assertEqual(result, conversation_1)
|
||||
self.assertEqual(len(result.past_user_inputs), 1)
|
||||
@@ -201,7 +201,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(result.generated_responses[0], "The Big Lebowski")
|
||||
# When
|
||||
conversation_1.add_user_input("Is it an action movie?")
|
||||
result = nlp(conversation_1, do_sample=False, max_length=36)
|
||||
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
||||
# Then
|
||||
self.assertEqual(result, conversation_1)
|
||||
self.assertEqual(len(result.past_user_inputs), 2)
|
||||
@@ -214,19 +214,19 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
def test_integration_torch_conversation_dialogpt_input_ids(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
|
||||
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
|
||||
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
conversation_1 = Conversation("hello")
|
||||
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||
inputs = conversation_agent._parse_and_tokenize([conversation_1])
|
||||
self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])
|
||||
|
||||
conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
|
||||
inputs = nlp._parse_and_tokenize([conversation_2])
|
||||
inputs = conversation_agent._parse_and_tokenize([conversation_2])
|
||||
self.assertEqual(
|
||||
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
|
||||
)
|
||||
|
||||
inputs = nlp._parse_and_tokenize([conversation_1, conversation_2])
|
||||
inputs = conversation_agent._parse_and_tokenize([conversation_1, conversation_2])
|
||||
self.assertEqual(
|
||||
inputs["input_ids"].tolist(),
|
||||
[
|
||||
@@ -240,11 +240,11 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
# test1
|
||||
conversation_1 = Conversation("hello")
|
||||
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||
inputs = conversation_agent._parse_and_tokenize([conversation_1])
|
||||
self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])
|
||||
|
||||
# test2
|
||||
@@ -255,7 +255,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
|
||||
],
|
||||
)
|
||||
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||
inputs = conversation_agent._parse_and_tokenize([conversation_1])
|
||||
self.assertEqual(
|
||||
inputs["input_ids"].tolist(),
|
||||
[
|
||||
@@ -310,10 +310,10 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
def test_integration_torch_conversation_blenderbot_400M(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
conversation_1 = Conversation("hello")
|
||||
result = nlp(
|
||||
result = conversation_agent(
|
||||
conversation_1,
|
||||
)
|
||||
self.assertEqual(
|
||||
@@ -325,7 +325,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
)
|
||||
|
||||
conversation_1 = Conversation("Lasagne hello")
|
||||
result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
|
||||
result = conversation_agent(conversation_1, encoder_no_repeat_ngram_size=3)
|
||||
self.assertEqual(
|
||||
result.generated_responses[0],
|
||||
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.",
|
||||
@@ -334,7 +334,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
conversation_1 = Conversation(
|
||||
"Lasagne hello Lasagne is my favorite Italian dish. Do you like lasagne? I like lasagne."
|
||||
)
|
||||
result = nlp(
|
||||
result = conversation_agent(
|
||||
conversation_1,
|
||||
encoder_no_repeat_ngram_size=3,
|
||||
)
|
||||
@@ -349,7 +349,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
# When
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot_small-90M")
|
||||
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer, device=DEFAULT_DEVICE_NUM)
|
||||
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer, device=DEFAULT_DEVICE_NUM)
|
||||
|
||||
conversation_1 = Conversation("My name is Sarah and I live in London")
|
||||
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
|
||||
@@ -357,7 +357,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
# When
|
||||
result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
self.assertEqual(result, [conversation_1, conversation_2])
|
||||
self.assertEqual(len(result[0].past_user_inputs), 1)
|
||||
@@ -378,7 +378,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
# When
|
||||
conversation_1.add_user_input("Not yet, what about you?")
|
||||
conversation_2.add_user_input("What's your name?")
|
||||
result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
self.assertEqual(result, [conversation_1, conversation_2])
|
||||
self.assertEqual(len(result[0].past_user_inputs), 2)
|
||||
|
||||
Reference in New Issue
Block a user