Adding new encoder_no_repeat_ngram_size to generate. (#9984)
Adding new `encoder_no_repeat_ngram_size` to `generate`. Blenderbot results seemed off compared to the original ParlAI script: `https://parl.ai/projects/recipes/`. Notably, the model seems to repeat a lot of what was said during the conversation. The actual problem was that ParlAI's `no_repeat_ngram_size` applies to the `encoder_input_ids`, whereas HF's `no_repeat_ngram_size` applies to the previously generated ids (within the decoder). The conversation history of Blenderbot is within the `encoder` part, which explains why HF's implementation had the repetitions. This fix was focused on Blenderbot *not* small and added tests for those because they are quite different in configuration. This change includes: - Adding a new EncoderNoRepeatNGramLogitsProcessor. - Adding 1 new arg to `generate` (`encoder_no_repeat_ngram_size`) - Adding 1 new config parameter `encoder_no_repeat_ngram_size`. - Adding 2 tests: one for the pipeline (high level, inputs exhibited repeat behavior) and one low level for EncoderNoRepeatNGramLogitsProcessor. - Factored NoRepeatLogitProcessor so that logic could be reused. Further work: - The Blenderbot conversational pipeline still does not behave correctly, as the way the input is prepared within the pipeline is still incorrect (follow up PR) - Blenderbot allows the bot to have personas, which is done by prepending "your persona: XXXX" to the input; this could be explored too in a follow up PR. @patrickvonplaten @LysandreJik * Update src/transformers/generation_logits_process.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/configuration_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Doc quality. * Fixing test. * Last fixes. * Fixing to account for batch_size. 
* Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ if is_torch_available():
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
@@ -208,6 +209,68 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
|
||||
)
|
||||
|
||||
def test_encoder_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
num_beams = 2
|
||||
batch_size = 1
|
||||
|
||||
encoder_input_ids = torch.tensor([1, 2, 1, 1], device=torch_device, dtype=torch.long)
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 1], [8, 0, 2]], device=torch_device, dtype=torch.long)
|
||||
scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
|
||||
|
||||
no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
|
||||
no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
|
||||
|
||||
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
|
||||
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
|
||||
|
||||
# 2-gram would forbid token ids 1 and 2 at 1st beam and token id 1 at 2nd beam
|
||||
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [False, True, False]])
|
||||
|
||||
# 3-gram would forbid token id 1 at 1st beam and no token at 2nd beam
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores_3_gram).tolist(), [[False, True, False], [False, False, False]]
|
||||
)
|
||||
|
||||
# Batched input
|
||||
vocab_size = 3
|
||||
num_beams = 2
|
||||
batch_size = 2
|
||||
encoder_input_ids = torch.tensor([[1, 2, 1, 1], [0, 0, 2, 1]], device=torch_device, dtype=torch.long)
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 1], [1, 0, 2], [0, 0, 0], [0, 2, 2]], device=torch_device, dtype=torch.long)
|
||||
scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
|
||||
|
||||
no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
|
||||
no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
|
||||
|
||||
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
|
||||
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
|
||||
|
||||
# 2gram
|
||||
# Batch 1
|
||||
# - Beam 1: tokens (1, 2) forbidden
|
||||
# - Beam 2: tokens (1) forbidden
|
||||
# Batch 2
|
||||
# - Beam 1: tokens (0, 2) forbidden
|
||||
# - Beam 2: tokens (1) forbidden
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores_2_gram).tolist(),
|
||||
[[False, True, True], [False, True, False], [True, False, True], [False, True, False]],
|
||||
)
|
||||
|
||||
# 3gram
# Batch 1
|
||||
# - Beam 1: tokens (1) forbidden
|
||||
# - Beam 2: tokens () forbidden
|
||||
# Batch 2
|
||||
# - Beam 1: tokens (2) forbidden
|
||||
# - Beam 2: tokens () forbidden
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores_3_gram).tolist(),
|
||||
[[False, True, False], [False, False, False], [False, False, True], [False, False, False]],
|
||||
)
|
||||
|
||||
def test_no_bad_words_dist_processor(self):
|
||||
vocab_size = 5
|
||||
batch_size = 2
|
||||
|
||||
@@ -276,6 +276,47 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
|
||||
self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation_blenderbot_400M(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
conversation_1 = Conversation("hello")
|
||||
result = nlp(
|
||||
conversation_1,
|
||||
)
|
||||
self.assertEqual(
|
||||
result.generated_responses[0],
|
||||
# ParlAI implementation output, we have a different one, but it's our
|
||||
# second best, you can check by using num_return_sequences=10
|
||||
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
|
||||
" Hello! How are you doing today? I just got back from a walk with my dog.",
|
||||
)
|
||||
|
||||
conversation_1 = Conversation(" Lasagne hello")
|
||||
result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
|
||||
self.assertEqual(
|
||||
result.generated_responses[0],
|
||||
" Lasagne is my favorite Italian dish. Do you like lasagne?",
|
||||
)
|
||||
|
||||
conversation_1 = Conversation(
|
||||
"Lasagne hello Lasagne is my favorite Italian dish. Do you like lasagne? I like lasagne."
|
||||
)
|
||||
result = nlp(
|
||||
conversation_1,
|
||||
encoder_no_repeat_ngram_size=3,
|
||||
)
|
||||
self.assertEqual(
|
||||
result.generated_responses[0],
|
||||
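# NOTE(review): the commented-out ParlAI output below is identical to the one in the first
# assertion ("hello" input) and looks copy-pasted; confirm the ParlAI reference for this input.
# ParlAI implementation output, we have a different one, but it's our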
|
||||
# second best, you can check by using num_return_sequences=10
|
||||
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
|
||||
" Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.",
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation_encoder_decoder(self):
|
||||
|
||||
Reference in New Issue
Block a user