Adding new encoder_no_repeat_ngram_size to generate. (#9984)
Adding new `encoder_no_repeat_ngram_size` to `generate`. Blenderbot results seemed off compared to original ParlAI script: `https://parl.ai/projects/recipes/`. Notably the model seems to repeat a lot what was said during the conversation. The actual problem was that `no_repeat_ngram_size` actually applies to the `encoder_input_ids` but HF's `no_repeat_ngram_size` applies to the previously generated ids (within the decoder). The history conversation of blenderbot is within the `encoder` part so that explains why HF's implementation had the repetitions. This fix was focused on blenderbot *not* small and added tests for those because they are quite different in configuration. This change includes: - Adding a new EncoderNoRepeatLogitProcessor. - Adding 1 new arg to `generate` (`encoder_no_repeat_ngram_size`) - Adding 1 new config parameter `encoder_no_repeat_ngram_size`. - Adding 2 tests, one for the pipeline (high level, inputs exhibited repeat behavior, one low level for EncoderNoRepeatLogitProcessor) - Factored NoRepeatLogitProcessor so that logic could be reused. Further work: - Blenderbot conversational pipeline still does not behave correctly as they way input is prepared within the pipeline is still incorrect (follow up PR) - Blenderbot allows the bot to have personas, which is done by prepending "your personna: XXXX" to the input, this could be explored too in a follow up PR. @patrickvonplaten @LysandreJik * Update src/transformers/generation_logits_process.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/configuration_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Doc quality. * Fixing test. * Last fixes. * Fixing to account for batch_size. * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -117,6 +117,9 @@ class PretrainedConfig(object):
|
|||||||
- **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default in the
|
- **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default in the
|
||||||
:obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of that size
|
:obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of that size
|
||||||
can only occur once.
|
can only occur once.
|
||||||
|
- **encoder_no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by
|
||||||
|
default in the :obj:`generate` method of the model for ``encoder_no_repeat_ngram_size``. If set to int > 0,
|
||||||
|
all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the ``decoder_input_ids``.
|
||||||
- **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be generated
|
- **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be generated
|
||||||
that will be used by default in the :obj:`generate` method of the model. In order to get the tokens of the
|
that will be used by default in the :obj:`generate` method of the model. In order to get the tokens of the
|
||||||
words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word,
|
words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word,
|
||||||
@@ -205,6 +208,7 @@ class PretrainedConfig(object):
|
|||||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||||
|
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
|
||||||
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
||||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||||
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
||||||
|
|||||||
@@ -235,6 +235,41 @@ class TopKLogitsWarper(LogitsWarper):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
|
||||||
|
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||||
|
for idx in range(num_hypos):
|
||||||
|
gen_tokens = prev_input_ids[idx].tolist()
|
||||||
|
generated_ngram = generated_ngrams[idx]
|
||||||
|
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
|
||||||
|
prev_ngram_tuple = tuple(ngram[:-1])
|
||||||
|
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
||||||
|
return generated_ngrams
|
||||||
|
|
||||||
|
|
||||||
|
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
|
||||||
|
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||||
|
start_idx = cur_len + 1 - ngram_size
|
||||||
|
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
|
||||||
|
return banned_ngrams.get(ngram_idx, [])
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_banned_ngram_tokens(
|
||||||
|
ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
|
||||||
|
) -> List[Iterable[int]]:
|
||||||
|
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||||
|
if cur_len + 1 < ngram_size:
|
||||||
|
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||||
|
return [[] for _ in range(num_hypos)]
|
||||||
|
|
||||||
|
generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
|
||||||
|
|
||||||
|
banned_tokens = [
|
||||||
|
_get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
|
||||||
|
for hypo_idx in range(num_hypos)
|
||||||
|
]
|
||||||
|
return banned_tokens
|
||||||
|
|
||||||
|
|
||||||
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
|
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
|
||||||
@@ -253,36 +288,53 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
|||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
num_batch_hypotheses = scores.shape[0]
|
num_batch_hypotheses = scores.shape[0]
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
|
banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
|
||||||
|
|
||||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||||
scores[i, banned_tokens] = -float("inf")
|
scores[i, banned_tokens] = -float("inf")
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def _calc_banned_ngram_tokens(
|
|
||||||
self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
|
|
||||||
) -> List[Iterable[int]]:
|
|
||||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
|
||||||
if cur_len + 1 < self.ngram_size:
|
|
||||||
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
|
||||||
return [[] for _ in range(num_hypos)]
|
|
||||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
|
||||||
for idx in range(num_hypos):
|
|
||||||
gen_tokens = prev_input_ids[idx].tolist()
|
|
||||||
generated_ngram = generated_ngrams[idx]
|
|
||||||
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
|
|
||||||
prev_ngram_tuple = tuple(ngram[:-1])
|
|
||||||
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
|
||||||
|
|
||||||
def _get_generated_ngrams(hypo_idx):
|
class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
|
||||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
r"""
|
||||||
start_idx = cur_len + 1 - self.ngram_size
|
:class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids.
|
||||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
See `ParlAI <https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350>`__.
|
||||||
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
|
||||||
|
|
||||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
Args:
|
||||||
return banned_tokens
|
encoder_ngram_size (:obj:`int`):
|
||||||
|
All ngrams of size :obj:`ngram_size` can only occur within the encoder input ids.
|
||||||
|
encoder_input_ids (:obj:`int`):
|
||||||
|
The encoder_input_ids that should not be repeated within the decoder ids.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
|
||||||
|
if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
|
||||||
|
)
|
||||||
|
self.ngram_size = encoder_ngram_size
|
||||||
|
if len(encoder_input_ids.shape) == 1:
|
||||||
|
encoder_input_ids = encoder_input_ids.unsqueeze(0)
|
||||||
|
self.batch_size = encoder_input_ids.shape[0]
|
||||||
|
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# B x num_beams
|
||||||
|
num_hypos = scores.shape[0]
|
||||||
|
num_beams = num_hypos // self.batch_size
|
||||||
|
cur_len = input_ids.shape[-1]
|
||||||
|
banned_batch_tokens = [
|
||||||
|
_get_generated_ngrams(
|
||||||
|
self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
|
||||||
|
)
|
||||||
|
for hypo_idx in range(num_hypos)
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||||
|
scores[i, banned_tokens] = -float("inf")
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class NoBadWordsLogitsProcessor(LogitsProcessor):
|
class NoBadWordsLogitsProcessor(LogitsProcessor):
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from torch.nn import functional as F
|
|||||||
from .file_utils import ModelOutput
|
from .file_utils import ModelOutput
|
||||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||||
from .generation_logits_process import (
|
from .generation_logits_process import (
|
||||||
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
@@ -537,6 +538,8 @@ class GenerationMixin:
|
|||||||
self,
|
self,
|
||||||
repetition_penalty: float,
|
repetition_penalty: float,
|
||||||
no_repeat_ngram_size: int,
|
no_repeat_ngram_size: int,
|
||||||
|
encoder_no_repeat_ngram_size: int,
|
||||||
|
encoder_input_ids: torch.LongTensor,
|
||||||
bad_words_ids: List[List[int]],
|
bad_words_ids: List[List[int]],
|
||||||
min_length: int,
|
min_length: int,
|
||||||
eos_token_id: int,
|
eos_token_id: int,
|
||||||
@@ -555,6 +558,11 @@ class GenerationMixin:
|
|||||||
no_repeat_ngram_size = (
|
no_repeat_ngram_size = (
|
||||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||||
)
|
)
|
||||||
|
encoder_no_repeat_ngram_size = (
|
||||||
|
encoder_no_repeat_ngram_size
|
||||||
|
if encoder_no_repeat_ngram_size is not None
|
||||||
|
else self.config.encoder_no_repeat_ngram_size
|
||||||
|
)
|
||||||
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
||||||
min_length = min_length if min_length is not None else self.config.min_length
|
min_length = min_length if min_length is not None else self.config.min_length
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
@@ -574,6 +582,13 @@ class GenerationMixin:
|
|||||||
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
||||||
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
|
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
|
||||||
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
|
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
|
||||||
|
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
|
||||||
|
)
|
||||||
if bad_words_ids is not None:
|
if bad_words_ids is not None:
|
||||||
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
|
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
|
||||||
if min_length is not None and eos_token_id is not None and min_length > -1:
|
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||||
@@ -601,6 +616,7 @@ class GenerationMixin:
|
|||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
length_penalty: Optional[float] = None,
|
length_penalty: Optional[float] = None,
|
||||||
no_repeat_ngram_size: Optional[int] = None,
|
no_repeat_ngram_size: Optional[int] = None,
|
||||||
|
encoder_no_repeat_ngram_size: Optional[int] = None,
|
||||||
num_return_sequences: Optional[int] = None,
|
num_return_sequences: Optional[int] = None,
|
||||||
decoder_start_token_id: Optional[int] = None,
|
decoder_start_token_id: Optional[int] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -661,6 +677,9 @@ class GenerationMixin:
|
|||||||
sequences.
|
sequences.
|
||||||
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||||
If set to int > 0, all ngrams of that size can only occur once.
|
If set to int > 0, all ngrams of that size can only occur once.
|
||||||
|
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
|
||||||
|
``decoder_input_ids``.
|
||||||
bad_words_ids(:obj:`List[List[int]]`, `optional`):
|
bad_words_ids(:obj:`List[List[int]]`, `optional`):
|
||||||
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
||||||
should not appear in the generated text, use :obj:`tokenizer(bad_word,
|
should not appear in the generated text, use :obj:`tokenizer(bad_word,
|
||||||
@@ -820,6 +839,9 @@ class GenerationMixin:
|
|||||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||||
pad_token_id = eos_token_id
|
pad_token_id = eos_token_id
|
||||||
|
|
||||||
|
# Storing encoder_input_ids for logits_processor that could use them
|
||||||
|
encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
# add encoder_outputs to model_kwargs
|
# add encoder_outputs to model_kwargs
|
||||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
|
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
|
||||||
@@ -862,6 +884,8 @@ class GenerationMixin:
|
|||||||
logits_processor = self._get_logits_processor(
|
logits_processor = self._get_logits_processor(
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
|
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||||
|
encoder_input_ids=encoder_input_ids,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
@@ -1638,6 +1662,7 @@ class GenerationMixin:
|
|||||||
beam_idx = beam_outputs["next_beam_indices"]
|
beam_idx = beam_outputs["next_beam_indices"]
|
||||||
|
|
||||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ class BlenderbotConfig(PretrainedConfig):
|
|||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
|
encoder_no_repeat_ngram_size=3,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -136,6 +137,7 @@ class BlenderbotConfig(PretrainedConfig):
|
|||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from transformers.generation_logits_process import (
|
from transformers.generation_logits_process import (
|
||||||
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
@@ -208,6 +209,68 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
|
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_encoder_no_repeat_ngram_dist_processor(self):
|
||||||
|
vocab_size = 3
|
||||||
|
num_beams = 2
|
||||||
|
batch_size = 1
|
||||||
|
|
||||||
|
encoder_input_ids = torch.tensor([1, 2, 1, 1], device=torch_device, dtype=torch.long)
|
||||||
|
|
||||||
|
input_ids = torch.tensor([[1, 2, 1], [8, 0, 2]], device=torch_device, dtype=torch.long)
|
||||||
|
scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
|
no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
|
||||||
|
no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
|
||||||
|
|
||||||
|
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
|
||||||
|
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
|
||||||
|
|
||||||
|
# 2-gram would forbid 1st and 2nd token at 1st beam and 1st token (0) at 2nd beam
|
||||||
|
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [False, True, False]])
|
||||||
|
|
||||||
|
# 3-gram would forbid 1st token at 1st beam and no token at 2nd beam
|
||||||
|
self.assertListEqual(
|
||||||
|
torch.isinf(filtered_scores_3_gram).tolist(), [[False, True, False], [False, False, False]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batched input
|
||||||
|
vocab_size = 3
|
||||||
|
num_beams = 2
|
||||||
|
batch_size = 2
|
||||||
|
encoder_input_ids = torch.tensor([[1, 2, 1, 1], [0, 0, 2, 1]], device=torch_device, dtype=torch.long)
|
||||||
|
|
||||||
|
input_ids = torch.tensor([[1, 2, 1], [1, 0, 2], [0, 0, 0], [0, 2, 2]], device=torch_device, dtype=torch.long)
|
||||||
|
scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
|
no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
|
||||||
|
no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
|
||||||
|
|
||||||
|
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
|
||||||
|
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
|
||||||
|
|
||||||
|
# 2gram
|
||||||
|
# Batch 1
|
||||||
|
# - Beam 1: tokens (1, 2) forbidden
|
||||||
|
# - Beam 2: tokens (1) forbidden
|
||||||
|
# Batch 2
|
||||||
|
# - Beam 1: tokens (0, 2) forbidden
|
||||||
|
# - Beam 2: tokens (1) forbidden
|
||||||
|
self.assertListEqual(
|
||||||
|
torch.isinf(filtered_scores_2_gram).tolist(),
|
||||||
|
[[False, True, True], [False, True, False], [True, False, True], [False, True, False]],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batch 1
|
||||||
|
# - Beam 1: tokens (1) forbidden
|
||||||
|
# - Beam 2: tokens () forbidden
|
||||||
|
# Batch 2
|
||||||
|
# - Beam 1: tokens (2) forbidden
|
||||||
|
# - Beam 2: tokens () forbidden
|
||||||
|
self.assertListEqual(
|
||||||
|
torch.isinf(filtered_scores_3_gram).tolist(),
|
||||||
|
[[False, True, False], [False, False, False], [False, False, True], [False, False, False]],
|
||||||
|
)
|
||||||
|
|
||||||
def test_no_bad_words_dist_processor(self):
|
def test_no_bad_words_dist_processor(self):
|
||||||
vocab_size = 5
|
vocab_size = 5
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
|
|||||||
@@ -276,6 +276,47 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
|||||||
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
|
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
|
||||||
self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_integration_torch_conversation_blenderbot_400M(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
|
||||||
|
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
conversation_1 = Conversation("hello")
|
||||||
|
result = nlp(
|
||||||
|
conversation_1,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
result.generated_responses[0],
|
||||||
|
# ParlAI implementation output, we have a different one, but it's our
|
||||||
|
# second best, you can check by using num_return_sequences=10
|
||||||
|
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
|
||||||
|
" Hello! How are you doing today? I just got back from a walk with my dog.",
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_1 = Conversation(" Lasagne hello")
|
||||||
|
result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
|
||||||
|
self.assertEqual(
|
||||||
|
result.generated_responses[0],
|
||||||
|
" Lasagne is my favorite Italian dish. Do you like lasagne?",
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_1 = Conversation(
|
||||||
|
"Lasagne hello Lasagne is my favorite Italian dish. Do you like lasagne? I like lasagne."
|
||||||
|
)
|
||||||
|
result = nlp(
|
||||||
|
conversation_1,
|
||||||
|
encoder_no_repeat_ngram_size=3,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
result.generated_responses[0],
|
||||||
|
# ParlAI implementation output, we have a different one, but it's our
|
||||||
|
# second best, you can check by using num_return_sequences=10
|
||||||
|
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
|
||||||
|
" Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.",
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_integration_torch_conversation_encoder_decoder(self):
|
def test_integration_torch_conversation_encoder_decoder(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user