Diverse beam search 2 (#9006)

* diverse beam search

* bug fixes

* bug fixes

* bug fix

* separate out diverse_beam_search function

* separate out diverse_beam_search function

* bug fix

* improve code quality

* bug fix

* bug fix

* separate out diverse beam search scorer

* code format

* code format

* code format

* code format

* add test

* code format

* documentation changes

* code quality

* add slow integration tests

* more general name

* refactor into logits processor

* add test

* avoid too much copy paste

* refactor

* add to docs

* fix-copies

* bug fix

* Revert "bug fix"

This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4.

* improve comment

* implement sylvains feedback

Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-12-09 15:00:37 +01:00
committed by GitHub
parent 67ff1c314a
commit 02d0e0355c
10 changed files with 590 additions and 20 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from abc import ABC
from typing import Callable, Iterable, List
@@ -37,6 +38,8 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
kwargs:
Additional logits processor specific kwargs.
Return:
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
@@ -75,9 +78,16 @@ class LogitsProcessorList(list):
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
for processor in self:
scores = processor(input_ids, scores)
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
assert all(
arg in kwargs for arg in list(function_args.keys())[2:]
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
return scores
@@ -400,3 +410,65 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
return scores + mask
class HammingDiversityLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only
effective for `group_beam_search`. See `Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models
<https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
Args:
diversity_penalty (:obj:`float`):
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled.
num_beams (:obj:`int`):
Number of beams used for group beam search. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for
more details.
num_beam_groups (:obj:`int`):
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
"""
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
self._diversity_penalty = diversity_penalty
if not isinstance(num_beams, int) or num_beams < 2:
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
self._num_beams = num_beams
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
if num_beam_groups > num_beams:
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
if num_beam_groups > num_beams:
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`")
self._num_sub_beams = num_beams // num_beam_groups
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
current_tokens: torch.LongTensor,
beam_group_idx: int,
) -> torch.FloatTensor:
# hamming diversity: penalise using same token in current group which was used in previous groups at
# the same time step
batch_size = current_tokens.shape[0] // self._num_beams
group_start_idx = beam_group_idx * self._num_sub_beams
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
group_size = group_end_idx - group_start_idx
vocab_size = scores.shape[-1]
if group_start_idx == 0:
return scores
for batch_idx in range(batch_size):
# predicted tokens of last time step of previous groups
previous_group_tokens = current_tokens[
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
]
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
return scores