Adding new encoder_no_repeat_ngram_size to generate. (#9984)

Adding new `encoder_no_repeat_ngram_size` to `generate`.

Blenderbot results seemed off compared to original ParlAI script:
`https://parl.ai/projects/recipes/`. Notably the model seems
to repeat a lot what was said during the conversation.

The actual problem was that `no_repeat_ngram_size` actually applies
to the `encoder_input_ids` but HF's `no_repeat_ngram_size` applies
to the previously generated ids (within the decoder). The history
conversation of blenderbot is within the `encoder` part so that
explains why HF's implementation had the repetitions.

This fix was focused on blenderbot *not* small and added tests
for those because they are quite different in configuration.

This change includes:

- Adding a new EncoderNoRepeatLogitProcessor.
- Adding 1 new arg to `generate` (`encoder_no_repeat_ngram_size`)
- Adding 1 new config parameter `encoder_no_repeat_ngram_size`.
- Adding 2 tests, one for the pipeline (high level, inputs exhibited
repeat behavior, one low level for EncoderNoRepeatLogitProcessor)
- Factored NoRepeatLogitProcessor so that logic could be reused.

Further work:

- Blenderbot conversational pipeline still does not behave correctly
 as they way input is prepared within the pipeline is still incorrect
(follow up PR)
- Blenderbot allows the bot to have personas, which is done by
prepending "your personna: XXXX" to the input, this could be explored
too in a follow up PR.

@patrickvonplaten
@LysandreJik

* Update src/transformers/generation_logits_process.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Doc quality.

* Fixing test.

* Last fixes.

* Fixing to account for batch_size.

* Update src/transformers/configuration_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/generation_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2021-02-04 15:00:18 +01:00
committed by GitHub
parent e89c959af9
commit aeb18b9224
6 changed files with 209 additions and 22 deletions

View File

@@ -23,6 +23,7 @@ from torch.nn import functional as F
from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
@@ -537,6 +538,8 @@ class GenerationMixin:
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
encoder_no_repeat_ngram_size: int,
encoder_input_ids: torch.LongTensor,
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
@@ -555,6 +558,11 @@ class GenerationMixin:
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
encoder_no_repeat_ngram_size = (
encoder_no_repeat_ngram_size
if encoder_no_repeat_ngram_size is not None
else self.config.encoder_no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
min_length = min_length if min_length is not None else self.config.min_length
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -574,6 +582,13 @@ class GenerationMixin:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
if self.config.is_encoder_decoder:
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
else:
raise ValueError(
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
)
if bad_words_ids is not None:
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
@@ -601,6 +616,7 @@ class GenerationMixin:
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
@@ -661,6 +677,9 @@ class GenerationMixin:
sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
``decoder_input_ids``.
bad_words_ids(:obj:`List[List[int]]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer(bad_word,
@@ -820,6 +839,9 @@ class GenerationMixin:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
# Storing encoder_input_ids for logits_processor that could use them
encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
@@ -862,6 +884,8 @@ class GenerationMixin:
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=encoder_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
@@ -1638,6 +1662,7 @@ class GenerationMixin:
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
cur_len = cur_len + 1
model_kwargs = self._update_model_kwargs_for_generation(