Making Conversation possible to create directly a full conversation (#9434)
* Cleaning up conversation tests. * Adding tests that don't require downloading models + conversation can be fully created from static state. * Making tests non flaky (by fixing generation length) * Bumping isort version. * Doc cleanup. * Remove unused test in this PR. * Torch import guard for TF. * Missing torch guard. * Small mistake in doc. * Actual uses `_history` and `_index` cache. + remove dead enumerate + improve warning message. * Update src/transformers/pipelines/conversational.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/conversational.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/conversational.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding comments and cleaner code to address history copy. * Improving pipeline name in tests. * Change tokenizer to a real one (still created at runtime with no external dependency) * Simplify DummyTok, reverse changes on tokenization. * Removing DummyTok. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -14,15 +14,177 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Conversation, ConversationalPipeline, pipeline
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
Conversation,
|
||||
ConversationalPipeline,
|
||||
is_torch_available,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||
|
||||
|
||||
class SimpleConversationPipelineTests(unittest.TestCase):
|
||||
def get_pipeline(self):
|
||||
# When
|
||||
config = GPT2Config(
|
||||
vocab_size=263,
|
||||
n_ctx=128,
|
||||
max_length=128,
|
||||
n_embd=64,
|
||||
n_layer=1,
|
||||
n_head=8,
|
||||
bos_token_id=256,
|
||||
eos_token_id=257,
|
||||
)
|
||||
model = GPT2LMHeadModel(config)
|
||||
# Force model output to be L
|
||||
V, D = model.lm_head.weight.shape
|
||||
bias = torch.zeros(V, requires_grad=True)
|
||||
bias[76] = 1
|
||||
|
||||
model.lm_head.bias = torch.nn.Parameter(bias)
|
||||
|
||||
# # Created with:
|
||||
# import tempfile
|
||||
|
||||
# from tokenizers import Tokenizer, models
|
||||
# from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
# vocab = [(chr(i), i) for i in range(256)]
|
||||
# tokenizer = Tokenizer(models.Unigram(vocab))
|
||||
# with tempfile.NamedTemporaryFile() as f:
|
||||
# tokenizer.save(f.name)
|
||||
# real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, eos_token="<eos>", bos_token="<bos>")
|
||||
|
||||
# real_tokenizer._tokenizer.save("dummy.json")
|
||||
# Special tokens are automatically added at load time.
|
||||
tokenizer = AutoTokenizer.from_pretrained("Narsil/small_conversational_test")
|
||||
conversation_agent = pipeline(
|
||||
task="conversational", device=DEFAULT_DEVICE_NUM, model=model, tokenizer=tokenizer
|
||||
)
|
||||
return conversation_agent
|
||||
|
||||
@require_torch
|
||||
def test_integration_torch_conversation(self):
|
||||
conversation_agent = self.get_pipeline()
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation_2 = Conversation("What's the last book you have read?")
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as log:
|
||||
result = conversation_agent([conversation_1, conversation_2], max_length=48)
|
||||
self.assertEqual(len(log.output), 2)
|
||||
self.assertIn("You might consider trimming the early phase of the conversation", log.output[0])
|
||||
self.assertIn("Setting `pad_token_id`", log.output[1])
|
||||
|
||||
# Two conversations in one pass
|
||||
self.assertEqual(result, [conversation_1, conversation_2])
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
Conversation(
|
||||
None,
|
||||
past_user_inputs=["Going to the movies tonight - any suggestions?"],
|
||||
generated_responses=["L"],
|
||||
),
|
||||
Conversation(
|
||||
None, past_user_inputs=["What's the last book you have read?"], generated_responses=["L"]
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# One conversation with history
|
||||
conversation_2.add_user_input("Why do you recommend it?")
|
||||
with self.assertLogs("transformers", level="WARNING") as log:
|
||||
result = conversation_agent(conversation_2, max_length=64)
|
||||
self.assertEqual(len(log.output), 3)
|
||||
self.assertIn("Cutting history off because it's too long", log.output[0])
|
||||
self.assertIn("You might consider trimming the early phase of the conversation", log.output[1])
|
||||
self.assertIn("Setting `pad_token_id`", log.output[2])
|
||||
|
||||
self.assertEqual(result, conversation_2)
|
||||
self.assertEqual(
|
||||
result,
|
||||
Conversation(
|
||||
None,
|
||||
past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
|
||||
generated_responses=["L", "L"],
|
||||
),
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_history_cache(self):
|
||||
conversation_agent = self.get_pipeline()
|
||||
conversation = Conversation(
|
||||
"Why do you recommend it?",
|
||||
past_user_inputs=["What's the last book you have read?"],
|
||||
generated_responses=["b"],
|
||||
)
|
||||
with self.assertLogs("transformers", level="WARNING") as log:
|
||||
_ = conversation_agent(conversation, max_length=60)
|
||||
self.assertEqual(len(log.output), 3)
|
||||
self.assertIn("Cutting history off because it's too long (63 > 28) for underlying model", log.output[0])
|
||||
self.assertIn("63 is bigger than 0.9 * max_length: 60", log.output[1])
|
||||
self.assertIn("Setting `pad_token_id`", log.output[2])
|
||||
self.assertEqual(conversation._index, 1)
|
||||
self.assertEqual(
|
||||
conversation._history,
|
||||
[
|
||||
87,
|
||||
104,
|
||||
97,
|
||||
116,
|
||||
39,
|
||||
115,
|
||||
32,
|
||||
116,
|
||||
104,
|
||||
101,
|
||||
32,
|
||||
108,
|
||||
97,
|
||||
115,
|
||||
116,
|
||||
32,
|
||||
98,
|
||||
111,
|
||||
111,
|
||||
107,
|
||||
32,
|
||||
121,
|
||||
111,
|
||||
117,
|
||||
32,
|
||||
104,
|
||||
97,
|
||||
118,
|
||||
101,
|
||||
32,
|
||||
114,
|
||||
101,
|
||||
97,
|
||||
100,
|
||||
63,
|
||||
259, # EOS
|
||||
98, # b
|
||||
259, # EOS
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "conversational"
|
||||
small_models = [] # Models tested without the @slow decorator
|
||||
|
||||
Reference in New Issue
Block a user