Diverse beam search 2 (#9006)
* diverse beam search * bug fixes * bug fixes * bug fix * separate out diverse_beam_search function * separate out diverse_beam_search function * bug fix * improve code quality * bug fix * bug fix * separate out diverse beam search scorer * code format * code format * code format * code format * add test * code format * documentation changes * code quality * add slow integration tests * more general name * refactor into logits processor * add test * avoid too much copy paste * refactor * add to docs * fix-copies * bug fix * Revert "bug fix" This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4. * improve comment * implement sylvains feedback Co-authored-by: Ayush Jain <a.jain@sprinklr.com> Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
67ff1c314a
commit
02d0e0355c
@@ -22,6 +22,7 @@ from torch.nn import functional as F
|
||||
from .file_utils import ModelOutput
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -261,6 +262,8 @@ class GenerationMixin:
|
||||
eos_token_id: int,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||
num_beams: int,
|
||||
num_beam_groups: int,
|
||||
diversity_penalty: float,
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
||||
@@ -275,11 +278,18 @@ class GenerationMixin:
|
||||
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
|
||||
# instantiate processors list
|
||||
processors = LogitsProcessorList()
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if diversity_penalty is not None and diversity_penalty > 0.0:
|
||||
processors.append(
|
||||
HammingDiversityLogitsProcessor(
|
||||
diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
|
||||
)
|
||||
)
|
||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
||||
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
||||
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
|
||||
@@ -314,6 +324,8 @@ class GenerationMixin:
|
||||
num_return_sequences: Optional[int] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
num_beam_groups: Optional[int] = None,
|
||||
diversity_penalty: Optional[float] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
**model_kwargs
|
||||
) -> torch.LongTensor:
|
||||
@@ -381,6 +393,13 @@ class GenerationMixin:
|
||||
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||
speed up decoding.
|
||||
num_beam_groups (:obj:`int`, `optional`, defaults to 1):
|
||||
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
|
||||
beams. `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
|
||||
This value is subtracted from a beam's score if it generates a token same as any beam from other group
|
||||
at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is
|
||||
enabled.
|
||||
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
|
||||
If provided, this function constraints the beam search to allowed tokens only at each step. If not
|
||||
provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID
|
||||
@@ -453,6 +472,7 @@ class GenerationMixin:
|
||||
|
||||
# set init values
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
num_return_sequences = (
|
||||
@@ -491,10 +511,17 @@ class GenerationMixin:
|
||||
raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
|
||||
|
||||
# determine generation mode
|
||||
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
|
||||
is_sample_gen_mode = (num_beams == 1) and do_sample is True
|
||||
is_beam_gen_mode = (num_beams > 1) and do_sample is False
|
||||
is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True
|
||||
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
|
||||
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
|
||||
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
|
||||
is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
|
||||
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||
if is_group_beam_gen_mode and do_sample is True:
|
||||
raise ValueError(
|
||||
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
||||
)
|
||||
|
||||
# set model_kwargs
|
||||
model_kwargs["use_cache"] = use_cache
|
||||
@@ -508,6 +535,8 @@ class GenerationMixin:
|
||||
eos_token_id=eos_token_id,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
diversity_penalty=diversity_penalty,
|
||||
)
|
||||
|
||||
if is_greedy_gen_mode:
|
||||
@@ -619,6 +648,42 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_group_beam_gen_mode:
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
|
||||
if num_return_sequences > num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if num_beams % num_beam_groups != 0:
|
||||
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
|
||||
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
do_early_stopping=early_stopping,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=num_beam_groups,
|
||||
)
|
||||
# interleave with `num_beams`
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
|
||||
)
|
||||
return self.group_beam_search(
|
||||
input_ids,
|
||||
diverse_beam_scorer,
|
||||
logits_processor=logits_processor,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def greedy_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@@ -1208,6 +1273,213 @@ class GenerationMixin:
|
||||
|
||||
return decoded
|
||||
|
||||
def group_beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
beam_scorer: BeamScorer,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head using beam search decoding.
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
|
||||
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
|
||||
beam_scorer (:obj:`BeamScorer`):
|
||||
An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
|
||||
constructed, stored and sorted during generation. For more information, the documentation of
|
||||
:class:`~transformers.BeamScorer` should be read.
|
||||
logits_processor (:obj:`LogitsProcessorList`, `optional`):
|
||||
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
|
||||
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
|
||||
head applied at each generation step.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the `end-of-sequence` token.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If
|
||||
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
||||
|
||||
Return:
|
||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
|
||||
batches finished early due to the :obj:`eos_token_id`.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import (
|
||||
... AutoTokenizer,
|
||||
... AutoModelForSeq2SeqLM,
|
||||
... LogitsProcessorList,
|
||||
... MinLengthLogitsProcessor,
|
||||
... HammingDiversityLogitsProcessor,
|
||||
... BeamSearchScorer,
|
||||
... )
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
>>> encoder_input_str = "translate English to German: How old are you?"
|
||||
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
|
||||
>>> # lets run diverse beam search using 6 beams
|
||||
>>> num_beams = 6
|
||||
>>> # define decoder start token ids
|
||||
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
|
||||
>>> input_ids = input_ids * model.config.decoder_start_token_id
|
||||
|
||||
>>> # add encoder_outputs to model keyword arguments
|
||||
>>> model_kwargs = {
|
||||
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
|
||||
... }
|
||||
|
||||
>>> # instantiate beam scorer
|
||||
>>> beam_scorer = BeamSearchScorer(
|
||||
... batch_size=1,
|
||||
... max_length=model.config.max_length,
|
||||
... num_beams=num_beams,
|
||||
... device=model.device,
|
||||
... num_beam_groups=3
|
||||
... )
|
||||
|
||||
>>> # instantiate logits processors
|
||||
>>> logits_processor = LogitsProcessorList([
|
||||
... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
|
||||
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
|
||||
... ])
|
||||
|
||||
>>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
|
||||
|
||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
"""
|
||||
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
num_beam_groups = beam_scorer.num_beam_groups
|
||||
num_sub_beams = num_beams // num_beam_groups
|
||||
device = input_ids.device
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
assert (
|
||||
num_beams * batch_size == batch_beam_size
|
||||
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
|
||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||
# the same group don't produce same tokens everytime.
|
||||
beam_scores[:, ::num_sub_beams] = 0
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
while cur_len < max_length:
|
||||
# predicted tokens in cur_len step
|
||||
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
||||
|
||||
# indices which will form the beams in the next time step
|
||||
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
||||
|
||||
# do one decoder step on all beams of all sentences in batch
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
outputs = self(**model_inputs, return_dict=True)
|
||||
|
||||
for beam_group_idx in range(num_beam_groups):
|
||||
group_start_idx = beam_group_idx * num_sub_beams
|
||||
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||
group_size = group_end_idx - group_start_idx
|
||||
|
||||
# indices of beams of current group among all sentences in batch
|
||||
batch_group_indices = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_group_indices.extend(
|
||||
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
||||
)
|
||||
group_input_ids = input_ids[batch_group_indices]
|
||||
|
||||
# select outputs of beams of current group only
|
||||
next_token_logits = outputs.logits[batch_group_indices, -1, :]
|
||||
|
||||
# adjust tokens for Bart, *e.g.*
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
||||
vocab_size = next_token_scores.shape[-1]
|
||||
|
||||
next_token_scores = logits_processor(
|
||||
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
||||
)
|
||||
next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
|
||||
next_token_scores
|
||||
)
|
||||
# reshape for beam search
|
||||
|
||||
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||
|
||||
next_token_scores, next_tokens = torch.topk(
|
||||
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = next_tokens // vocab_size
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
beam_outputs = beam_scorer.process(
|
||||
group_input_ids,
|
||||
next_token_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
||||
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
||||
|
||||
# (beam_idx // group_size) -> batch_idx
|
||||
# (beam_idx % group_size) -> offset of idx inside the group
|
||||
reordering_indices[batch_group_indices] = (
|
||||
num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
|
||||
)
|
||||
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)
|
||||
|
||||
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
||||
cur_len = cur_len + 1
|
||||
if beam_scorer.is_done:
|
||||
break
|
||||
|
||||
decoded = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
)
|
||||
|
||||
return decoded
|
||||
|
||||
|
||||
def top_k_top_p_filtering(
|
||||
logits: torch.FloatTensor,
|
||||
|
||||
Reference in New Issue
Block a user