Adding new encoder_no_repeat_ngram_size to generate. (#9984)
Adding new `encoder_no_repeat_ngram_size` to `generate`. Blenderbot results seemed off compared to original ParlAI script: `https://parl.ai/projects/recipes/`. Notably the model seems to repeat a lot what was said during the conversation. The actual problem was that `no_repeat_ngram_size` actually applies to the `encoder_input_ids` but HF's `no_repeat_ngram_size` applies to the previously generated ids (within the decoder). The history conversation of blenderbot is within the `encoder` part so that explains why HF's implementation had the repetitions. This fix was focused on blenderbot *not* small and added tests for those because they are quite different in configuration. This change includes: - Adding a new EncoderNoRepeatLogitProcessor. - Adding 1 new arg to `generate` (`encoder_no_repeat_ngram_size`) - Adding 1 new config parameter `encoder_no_repeat_ngram_size`. - Adding 2 tests, one for the pipeline (high level, inputs exhibited repeat behavior, one low level for EncoderNoRepeatLogitProcessor) - Factored NoRepeatLogitProcessor so that logic could be reused. Further work: - Blenderbot conversational pipeline still does not behave correctly as they way input is prepared within the pipeline is still incorrect (follow up PR) - Blenderbot allows the bot to have personas, which is done by prepending "your personna: XXXX" to the input, this could be explored too in a follow up PR. @patrickvonplaten @LysandreJik * Update src/transformers/generation_logits_process.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/configuration_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Doc quality. * Fixing test. * Last fixes. * Fixing to account for batch_size. * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -117,6 +117,9 @@ class PretrainedConfig(object):
|
||||
- **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default in the
|
||||
:obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of that size
|
||||
can only occur once.
|
||||
- **encoder_no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by
|
||||
default in the :obj:`generate` method of the model for ``encoder_no_repeat_ngram_size``. If set to int > 0,
|
||||
all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the ``decoder_input_ids``.
|
||||
- **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be generated
|
||||
that will be used by default in the :obj:`generate` method of the model. In order to get the tokens of the
|
||||
words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word,
|
||||
@@ -205,6 +208,7 @@ class PretrainedConfig(object):
|
||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
|
||||
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
||||
|
||||
@@ -235,6 +235,41 @@ class TopKLogitsWarper(LogitsWarper):
|
||||
return scores
|
||||
|
||||
|
||||
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
|
||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||
for idx in range(num_hypos):
|
||||
gen_tokens = prev_input_ids[idx].tolist()
|
||||
generated_ngram = generated_ngrams[idx]
|
||||
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
|
||||
prev_ngram_tuple = tuple(ngram[:-1])
|
||||
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
||||
return generated_ngrams
|
||||
|
||||
|
||||
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
|
||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||
start_idx = cur_len + 1 - ngram_size
|
||||
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
|
||||
return banned_ngrams.get(ngram_idx, [])
|
||||
|
||||
|
||||
def _calc_banned_ngram_tokens(
|
||||
ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
|
||||
) -> List[Iterable[int]]:
|
||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
if cur_len + 1 < ngram_size:
|
||||
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
return [[] for _ in range(num_hypos)]
|
||||
|
||||
generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
|
||||
|
||||
banned_tokens = [
|
||||
_get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
|
||||
for hypo_idx in range(num_hypos)
|
||||
]
|
||||
return banned_tokens
|
||||
|
||||
|
||||
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
|
||||
@@ -253,36 +288,53 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
num_batch_hypotheses = scores.shape[0]
|
||||
cur_len = input_ids.shape[-1]
|
||||
banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
|
||||
banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
|
||||
|
||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
|
||||
return scores
|
||||
|
||||
def _calc_banned_ngram_tokens(
|
||||
self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
|
||||
) -> List[Iterable[int]]:
|
||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
if cur_len + 1 < self.ngram_size:
|
||||
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
return [[] for _ in range(num_hypos)]
|
||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||
for idx in range(num_hypos):
|
||||
gen_tokens = prev_input_ids[idx].tolist()
|
||||
generated_ngram = generated_ngrams[idx]
|
||||
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
|
||||
prev_ngram_tuple = tuple(ngram[:-1])
|
||||
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
||||
|
||||
def _get_generated_ngrams(hypo_idx):
|
||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||
start_idx = cur_len + 1 - self.ngram_size
|
||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
||||
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||
class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids.
|
||||
See `ParlAI <https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350>`__.
|
||||
|
||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||
return banned_tokens
|
||||
Args:
|
||||
encoder_ngram_size (:obj:`int`):
|
||||
All ngrams of size :obj:`ngram_size` can only occur within the encoder input ids.
|
||||
encoder_input_ids (:obj:`int`):
|
||||
The encoder_input_ids that should not be repeated within the decoder ids.
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
|
||||
if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
|
||||
raise ValueError(
|
||||
f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
|
||||
)
|
||||
self.ngram_size = encoder_ngram_size
|
||||
if len(encoder_input_ids.shape) == 1:
|
||||
encoder_input_ids = encoder_input_ids.unsqueeze(0)
|
||||
self.batch_size = encoder_input_ids.shape[0]
|
||||
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# B x num_beams
|
||||
num_hypos = scores.shape[0]
|
||||
num_beams = num_hypos // self.batch_size
|
||||
cur_len = input_ids.shape[-1]
|
||||
banned_batch_tokens = [
|
||||
_get_generated_ngrams(
|
||||
self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
|
||||
)
|
||||
for hypo_idx in range(num_hypos)
|
||||
]
|
||||
|
||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class NoBadWordsLogitsProcessor(LogitsProcessor):
|
||||
|
||||
@@ -23,6 +23,7 @@ from torch.nn import functional as F
|
||||
from .file_utils import ModelOutput
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
@@ -537,6 +538,8 @@ class GenerationMixin:
|
||||
self,
|
||||
repetition_penalty: float,
|
||||
no_repeat_ngram_size: int,
|
||||
encoder_no_repeat_ngram_size: int,
|
||||
encoder_input_ids: torch.LongTensor,
|
||||
bad_words_ids: List[List[int]],
|
||||
min_length: int,
|
||||
eos_token_id: int,
|
||||
@@ -555,6 +558,11 @@ class GenerationMixin:
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
)
|
||||
encoder_no_repeat_ngram_size = (
|
||||
encoder_no_repeat_ngram_size
|
||||
if encoder_no_repeat_ngram_size is not None
|
||||
else self.config.encoder_no_repeat_ngram_size
|
||||
)
|
||||
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
@@ -574,6 +582,13 @@ class GenerationMixin:
|
||||
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
||||
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
|
||||
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
|
||||
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
|
||||
if self.config.is_encoder_decoder:
|
||||
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
|
||||
else:
|
||||
raise ValueError(
|
||||
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
|
||||
)
|
||||
if bad_words_ids is not None:
|
||||
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
|
||||
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||
@@ -601,6 +616,7 @@ class GenerationMixin:
|
||||
eos_token_id: Optional[int] = None,
|
||||
length_penalty: Optional[float] = None,
|
||||
no_repeat_ngram_size: Optional[int] = None,
|
||||
encoder_no_repeat_ngram_size: Optional[int] = None,
|
||||
num_return_sequences: Optional[int] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@@ -661,6 +677,9 @@ class GenerationMixin:
|
||||
sequences.
|
||||
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||
If set to int > 0, all ngrams of that size can only occur once.
|
||||
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
|
||||
``decoder_input_ids``.
|
||||
bad_words_ids(:obj:`List[List[int]]`, `optional`):
|
||||
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
||||
should not appear in the generated text, use :obj:`tokenizer(bad_word,
|
||||
@@ -820,6 +839,9 @@ class GenerationMixin:
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# Storing encoder_input_ids for logits_processor that could use them
|
||||
encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
# add encoder_outputs to model_kwargs
|
||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
|
||||
@@ -862,6 +884,8 @@ class GenerationMixin:
|
||||
logits_processor = self._get_logits_processor(
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||
encoder_input_ids=encoder_input_ids,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
eos_token_id=eos_token_id,
|
||||
@@ -1638,6 +1662,7 @@ class GenerationMixin:
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
|
||||
@@ -128,6 +128,7 @@ class BlenderbotConfig(PretrainedConfig):
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
encoder_no_repeat_ngram_size=3,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -136,6 +137,7 @@ class BlenderbotConfig(PretrainedConfig):
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user