Adding new encoder_no_repeat_ngram_size to generate. (#9984)
Adding new `encoder_no_repeat_ngram_size` to `generate`. Blenderbot results seemed off compared to original ParlAI script: `https://parl.ai/projects/recipes/`. Notably the model seems to repeat a lot what was said during the conversation. The actual problem was that `no_repeat_ngram_size` actually applies to the `encoder_input_ids` but HF's `no_repeat_ngram_size` applies to the previously generated ids (within the decoder). The history conversation of blenderbot is within the `encoder` part so that explains why HF's implementation had the repetitions. This fix was focused on blenderbot *not* small and added tests for those because they are quite different in configuration. This change includes: - Adding a new EncoderNoRepeatLogitProcessor. - Adding 1 new arg to `generate` (`encoder_no_repeat_ngram_size`) - Adding 1 new config parameter `encoder_no_repeat_ngram_size`. - Adding 2 tests, one for the pipeline (high level, inputs exhibited repeat behavior, one low level for EncoderNoRepeatLogitProcessor) - Factored NoRepeatLogitProcessor so that logic could be reused. Further work: - Blenderbot conversational pipeline still does not behave correctly as they way input is prepared within the pipeline is still incorrect (follow up PR) - Blenderbot allows the bot to have personas, which is done by prepending "your personna: XXXX" to the input, this could be explored too in a follow up PR. @patrickvonplaten @LysandreJik * Update src/transformers/generation_logits_process.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/configuration_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Doc quality. * Fixing test. * Last fixes. * Fixing to account for batch_size. * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from torch.nn import functional as F
|
||||
from .file_utils import ModelOutput
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
@@ -537,6 +538,8 @@ class GenerationMixin:
|
||||
self,
|
||||
repetition_penalty: float,
|
||||
no_repeat_ngram_size: int,
|
||||
encoder_no_repeat_ngram_size: int,
|
||||
encoder_input_ids: torch.LongTensor,
|
||||
bad_words_ids: List[List[int]],
|
||||
min_length: int,
|
||||
eos_token_id: int,
|
||||
@@ -555,6 +558,11 @@ class GenerationMixin:
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
)
|
||||
encoder_no_repeat_ngram_size = (
|
||||
encoder_no_repeat_ngram_size
|
||||
if encoder_no_repeat_ngram_size is not None
|
||||
else self.config.encoder_no_repeat_ngram_size
|
||||
)
|
||||
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
@@ -574,6 +582,13 @@ class GenerationMixin:
|
||||
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
||||
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
|
||||
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
|
||||
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
|
||||
if self.config.is_encoder_decoder:
|
||||
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
|
||||
else:
|
||||
raise ValueError(
|
||||
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
|
||||
)
|
||||
if bad_words_ids is not None:
|
||||
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
|
||||
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||
@@ -601,6 +616,7 @@ class GenerationMixin:
|
||||
eos_token_id: Optional[int] = None,
|
||||
length_penalty: Optional[float] = None,
|
||||
no_repeat_ngram_size: Optional[int] = None,
|
||||
encoder_no_repeat_ngram_size: Optional[int] = None,
|
||||
num_return_sequences: Optional[int] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@@ -661,6 +677,9 @@ class GenerationMixin:
|
||||
sequences.
|
||||
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||
If set to int > 0, all ngrams of that size can only occur once.
|
||||
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
|
||||
``decoder_input_ids``.
|
||||
bad_words_ids(:obj:`List[List[int]]`, `optional`):
|
||||
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
||||
should not appear in the generated text, use :obj:`tokenizer(bad_word,
|
||||
@@ -820,6 +839,9 @@ class GenerationMixin:
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# Storing encoder_input_ids for logits_processor that could use them
|
||||
encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
# add encoder_outputs to model_kwargs
|
||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
|
||||
@@ -862,6 +884,8 @@ class GenerationMixin:
|
||||
logits_processor = self._get_logits_processor(
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||
encoder_input_ids=encoder_input_ids,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
eos_token_id=eos_token_id,
|
||||
@@ -1638,6 +1662,7 @@ class GenerationMixin:
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
|
||||
Reference in New Issue
Block a user