Diverse beam search 2 (#9006)
* diverse beam search * bug fixes * bug fixes * bug fix * separate out diverse_beam_search function * separate out diverse_beam_search function * bug fix * improve code quality * bug fix * bug fix * separate out diverse beam search scorer * code format * code format * code format * code format * add test * code format * documentation changes * code quality * add slow integration tests * more general name * refactor into logits processor * add test * avoid too much copy paste * refactor * add to docs * fix-copies * bug fix * Revert "bug fix" This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4. * improve comment * implement sylvains feedback Co-authored-by: Ayush Jain <a.jain@sprinklr.com> Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
67ff1c314a
commit
02d0e0355c
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from abc import ABC
|
||||
from typing import Callable, Iterable, List
|
||||
@@ -37,6 +38,8 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
|
||||
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
|
||||
or scores for each vocabulary token after SoftMax.
|
||||
kwargs:
|
||||
Additional logits processor specific kwargs.
|
||||
|
||||
Return:
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
||||
@@ -75,9 +78,16 @@ class LogitsProcessorList(list):
|
||||
"""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
for processor in self:
|
||||
scores = processor(input_ids, scores)
|
||||
function_args = inspect.signature(processor.__call__).parameters
|
||||
if len(function_args) > 2:
|
||||
assert all(
|
||||
arg in kwargs for arg in list(function_args.keys())[2:]
|
||||
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
|
||||
scores = processor(input_ids, scores, **kwargs)
|
||||
else:
|
||||
scores = processor(input_ids, scores)
|
||||
return scores
|
||||
|
||||
|
||||
@@ -400,3 +410,65 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
|
||||
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
|
||||
|
||||
return scores + mask
|
||||
|
||||
|
||||
class HammingDiversityLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only
|
||||
effective for `group_beam_search`. See `Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models
|
||||
<https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
|
||||
Args:
|
||||
diversity_penalty (:obj:`float`):
|
||||
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
||||
particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled.
|
||||
num_beams (:obj:`int`):
|
||||
Number of beams used for group beam search. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for
|
||||
more details.
|
||||
num_beam_groups (:obj:`int`):
|
||||
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
|
||||
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
|
||||
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
|
||||
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
|
||||
self._diversity_penalty = diversity_penalty
|
||||
if not isinstance(num_beams, int) or num_beams < 2:
|
||||
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
|
||||
self._num_beams = num_beams
|
||||
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
|
||||
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`")
|
||||
self._num_sub_beams = num_beams // num_beam_groups
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
current_tokens: torch.LongTensor,
|
||||
beam_group_idx: int,
|
||||
) -> torch.FloatTensor:
|
||||
# hamming diversity: penalise using same token in current group which was used in previous groups at
|
||||
# the same time step
|
||||
batch_size = current_tokens.shape[0] // self._num_beams
|
||||
group_start_idx = beam_group_idx * self._num_sub_beams
|
||||
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
|
||||
group_size = group_end_idx - group_start_idx
|
||||
vocab_size = scores.shape[-1]
|
||||
|
||||
if group_start_idx == 0:
|
||||
return scores
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
# predicted tokens of last time step of previous groups
|
||||
previous_group_tokens = current_tokens[
|
||||
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
|
||||
]
|
||||
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
|
||||
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
|
||||
|
||||
return scores
|
||||
|
||||
Reference in New Issue
Block a user