Constrained Beam Search [without disjunctive decoding] (#15416)
* added classes to get started with constrained beam search * in progress, think i can directly force tokens now but not yet with the round robin * think now i have total control, now need to code the bank selection * technically works as desired, need to optimize and fix design choices leading to undersirable outputs * complete PR #1 without disjunctive decoding * removed incorrect tests * Delete k.txt * Delete test.py * Delete test.sh * revert changes to test scripts * genutils * full implementation with testing, no disjunctive yet * shifted docs * passing all tests realistically ran locally * removing accidentally included print statements * fixed source of error in initial PR test * fixing the get_device() vs device trap * fixed documentation docstrings about constrained_beam_search * fixed tests having failing for Speech2TextModel's floating point inputs * fix cuda long tensor * added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search * deleted accidentally added test halting code with assert False * code reformat * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py * fixing based on comments on PR * took out the testing code that should but work fails without the beam search moditification ; style changes * fixing comments issues * docstrings for ConstraintListState * typo in PhrsalConstraint docstring * docstrings improvements Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -16,8 +16,9 @@ This page lists all the utility functions used by [`~generation_utils.Generation
|
||||
[`~generation_utils.GenerationMixin.greedy_search`],
|
||||
[`~generation_utils.GenerationMixin.sample`],
|
||||
[`~generation_utils.GenerationMixin.beam_search`],
|
||||
[`~generation_utils.GenerationMixin.beam_sample`], and
|
||||
[`~generation_utils.GenerationMixin.group_beam_search`].
|
||||
[`~generation_utils.GenerationMixin.beam_sample`],
|
||||
[`~generation_utils.GenerationMixin.group_beam_search`], and
|
||||
[`~generation_utils.GenerationMixin.constrained_beam_search`].
|
||||
|
||||
Most of those are only useful if you are studying the code of the generate methods in the library.
|
||||
|
||||
@@ -190,6 +191,16 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
|
||||
[[autodoc]] MaxTimeCriteria
|
||||
- __call__
|
||||
|
||||
## Constraints
|
||||
|
||||
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
|
||||
|
||||
[[autodoc]] Constraint
|
||||
|
||||
[[autodoc]] PhrasalConstraint
|
||||
|
||||
[[autodoc]] ConstraintListState
|
||||
|
||||
## BeamSearch
|
||||
|
||||
[[autodoc]] BeamScorer
|
||||
@@ -200,6 +211,10 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
|
||||
- process
|
||||
- finalize
|
||||
|
||||
[[autodoc]] ConstrainedBeamSearchScorer
|
||||
- process
|
||||
- finalize
|
||||
|
||||
## Utilities
|
||||
|
||||
[[autodoc]] top_k_top_p_filtering
|
||||
|
||||
@@ -612,7 +612,12 @@ if is_torch_available():
|
||||
"TextDatasetForNextSentencePrediction",
|
||||
]
|
||||
_import_structure["deepspeed"] = []
|
||||
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
|
||||
_import_structure["generation_beam_constraints"] = [
|
||||
"Constraint",
|
||||
"ConstraintListState",
|
||||
"PhrasalConstraint",
|
||||
]
|
||||
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
|
||||
_import_structure["generation_logits_process"] = [
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
@@ -2750,7 +2755,8 @@ if TYPE_CHECKING:
|
||||
TextDataset,
|
||||
TextDatasetForNextSentencePrediction,
|
||||
)
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
|
||||
367
src/transformers/generation_beam_constraints.py
Normal file
367
src/transformers/generation_beam_constraints.py
Normal file
@@ -0,0 +1,367 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Constraint(ABC):
|
||||
r"""Abstract base class for all constraints that can be applied during generation.
|
||||
It must define how the constraint can be satisfied.
|
||||
|
||||
All classes that inherit Constraint must follow the requirement that
|
||||
|
||||
```py
|
||||
completed = False
|
||||
while not completed:
|
||||
_, completed = constraint.update(constraint.advance())
|
||||
```
|
||||
|
||||
will always terminate (halt).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# test for the above condition
|
||||
self.test()
|
||||
|
||||
def test(self):
|
||||
"""
|
||||
Tests whether this constraint has been properly defined.
|
||||
"""
|
||||
counter = 0
|
||||
completed = False
|
||||
while not completed:
|
||||
if counter == 1:
|
||||
self.reset()
|
||||
advance = self.advance()
|
||||
if not self.does_advance(advance):
|
||||
raise Exception(
|
||||
"Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
|
||||
)
|
||||
|
||||
stepped, completed, reset = self.update(advance)
|
||||
counter += 1
|
||||
|
||||
if counter > 10000:
|
||||
raise Exception("update() does not fulfill the constraint.")
|
||||
|
||||
if self.remaining() != 0:
|
||||
raise Exception("Custom Constraint is not defined correctly.")
|
||||
|
||||
@abstractmethod
|
||||
def advance(self):
|
||||
"""
|
||||
When called, returns the token that would take this constraint one step closer to being fulfilled.
|
||||
|
||||
Return:
|
||||
token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def does_advance(self, token_id: int):
|
||||
"""
|
||||
Reads in a token and returns whether it creates progress.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def update(self, token_id: int):
|
||||
"""
|
||||
Reads in a token and returns booleans that indicate the progress made by it. This function will update the
|
||||
state of this object unlikes `does_advance(self, token_id: int)`.
|
||||
|
||||
This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
|
||||
been generated. This becomes important if token_id != desired token (refer to else statement in
|
||||
PhrasalConstraint)
|
||||
|
||||
Args:
|
||||
token_id(`int`):
|
||||
The id of a newly generated token in the beam search.
|
||||
Return:
|
||||
stepped(`bool`):
|
||||
Whether this constraint has become one step closer to being fulfuilled.
|
||||
completed(`bool`):
|
||||
Whether this constraint has been completely fulfilled by this token being generated.
|
||||
reset (`bool`):
|
||||
Whether this constraint has reset its progress by this token being generated.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
|
||||
a constraint is abrupted by an unwanted token.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def remaining(self):
|
||||
"""
|
||||
Returns the number of remaining steps of `advance()` in order to complete this constraint.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def copy(self, stateful=False):
|
||||
"""
|
||||
Creates a new instance of this constraint.
|
||||
|
||||
Args:
|
||||
stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state.
|
||||
|
||||
Return:
|
||||
constraint(`Constraint`): The same constraint as the one being called from.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
|
||||
class PhrasalConstraint(Constraint):
|
||||
r"""
|
||||
[`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
|
||||
|
||||
Args:
|
||||
token_ids (`List[int]`):
|
||||
The id of the token that must be generated by the output.
|
||||
"""
|
||||
|
||||
def __init__(self, token_ids: Union[List[int], torch.LongTensor]):
|
||||
super(Constraint, self).__init__()
|
||||
|
||||
is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int)
|
||||
is_tensor = isinstance(token_ids, torch.Tensor)
|
||||
is_int_tensor = (
|
||||
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1
|
||||
)
|
||||
not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0
|
||||
if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive:
|
||||
raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}")
|
||||
|
||||
if not is_tensor:
|
||||
token_ids = torch.tensor(token_ids)
|
||||
|
||||
self.token_ids = token_ids
|
||||
|
||||
self.seqlen = self.token_ids.size(0)
|
||||
self.fulfilled_idx = -1 # the index of the currently fulfilled step
|
||||
self.completed = False
|
||||
|
||||
def advance(self):
|
||||
return self.token_ids[self.fulfilled_idx + 1]
|
||||
|
||||
def does_advance(self, token_id: int):
|
||||
if self.completed:
|
||||
return False
|
||||
# move to cpu to guarantee no device issues.
|
||||
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu()
|
||||
|
||||
def update(self, token_id: int):
|
||||
stepped = False
|
||||
completed = False
|
||||
reset = False
|
||||
|
||||
if self.does_advance(token_id):
|
||||
self.fulfilled_idx += 1
|
||||
stepped = True
|
||||
if self.fulfilled_idx == (self.seqlen - 1):
|
||||
completed = True
|
||||
self.completed = completed
|
||||
else:
|
||||
# failed to make progress.
|
||||
reset = True
|
||||
self.reset()
|
||||
return stepped, completed, reset
|
||||
|
||||
def reset(self):
|
||||
self.completed = False
|
||||
self.fulfilled_idx = 0
|
||||
|
||||
def remaining(self):
|
||||
return self.seqlen - (self.fulfilled_idx + 1)
|
||||
|
||||
def copy(self, stateful=False):
|
||||
new_constraint = PhrasalConstraint(self.token_ids)
|
||||
|
||||
if stateful:
|
||||
new_constraint.seq_len = self.seqlen
|
||||
new_constraint.fulfilled_idx = self.fulfilled_idx
|
||||
new_constraint.completed = self.completed
|
||||
|
||||
return new_constraint
|
||||
|
||||
|
||||
class ConstraintListState:
|
||||
r"""
|
||||
A class for beam scorers to track its progress through a list of constraints.
|
||||
|
||||
Args:
|
||||
constraints (`List[Constraint]`):
|
||||
A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
|
||||
"""
|
||||
|
||||
def __init__(self, constraints: List[Constraint]):
|
||||
self.constraints = constraints
|
||||
|
||||
# max # of steps required to fulfill a given constraint
|
||||
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)])
|
||||
self.n_constraints = len(constraints)
|
||||
self.completed = False
|
||||
|
||||
self.init_state()
|
||||
|
||||
def init_state(self):
|
||||
self.complete_constraints = []
|
||||
self.inprogress_constraint = None
|
||||
self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
|
||||
|
||||
def get_bank(self):
|
||||
add = 0
|
||||
if self.inprogress_constraint:
|
||||
# extra points for having a constraint mid-fulfilled
|
||||
add += self.max_seqlen - self.inprogress_constraint.remaining()
|
||||
|
||||
return (len(self.complete_constraints) * self.max_seqlen) + add
|
||||
|
||||
def advance(self):
|
||||
"""The list of tokens to generate such that we can make progress.
|
||||
By "list" we don't mean the list of token that will fully fulfill a constraint.
|
||||
|
||||
Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
|
||||
specific constraint `c_i`, we return:
|
||||
|
||||
`[t_k1 for k in indices of unfulfilled constraints]`
|
||||
|
||||
If we are in the middle of a constraint, then we return:
|
||||
`[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
|
||||
|
||||
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
|
||||
that's the only one we'll return.
|
||||
"""
|
||||
if self.inprogress_constraint is None:
|
||||
token_list = []
|
||||
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
|
||||
advance = constraint.advance()
|
||||
token_list.append(advance)
|
||||
else:
|
||||
token_list = [self.inprogress_constraint.advance()]
|
||||
|
||||
if len(token_list) == 0:
|
||||
return None
|
||||
else:
|
||||
return torch.stack(token_list)
|
||||
|
||||
def reset(self, token_ids: Optional[torch.LongTensor]):
|
||||
"""
|
||||
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
|
||||
"""
|
||||
self.init_state()
|
||||
|
||||
if token_ids is not None and token_ids.size(0) > 0:
|
||||
for token in token_ids:
|
||||
# completes or steps **one** constraint
|
||||
complete, stepped = self.add(token)
|
||||
|
||||
# the entire list of constraints are fulfilled
|
||||
if self.completed:
|
||||
break
|
||||
|
||||
return self
|
||||
|
||||
def add(self, token_id: Union[int, torch.LongTensor]):
|
||||
complete, stepped = False, False
|
||||
|
||||
if self.completed:
|
||||
complete = True
|
||||
stepped = False
|
||||
return complete, stepped
|
||||
|
||||
if self.inprogress_constraint is not None:
|
||||
# In the middle of fulfilling a constraint. If the `token_id` *does* makes an incremental progress to current
|
||||
# job, simply update the state
|
||||
|
||||
stepped, complete, reset = self.inprogress_constraint.update(token_id)
|
||||
if reset:
|
||||
# 1. If the next token breaks the progress, then we must restart.
|
||||
# e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
|
||||
|
||||
# But that doesn't mean we self.init_state(), since we only reset the state for this particular
|
||||
# constraint, not the full list of constraints.
|
||||
|
||||
self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
|
||||
self.inprogress_constraint = None
|
||||
|
||||
if complete:
|
||||
# 2. If the next token completes the constraint, move it to completed list, set
|
||||
# inprogress to None. If there are no pending constraints either, then this full list of constraints
|
||||
# is complete.
|
||||
|
||||
self.complete_constraints.append(self.inprogress_constraint)
|
||||
self.inprogress_constraint = None
|
||||
|
||||
if len(self.pending_constraints) == 0:
|
||||
# we're done!
|
||||
self.completed = True
|
||||
|
||||
else:
|
||||
# Not in the middle of fulfilling a constraint. So does this `token_id` helps us step towards any of our list
|
||||
# of constraints?
|
||||
|
||||
for cidx, pending_constraint in enumerate(self.pending_constraints):
|
||||
if pending_constraint.does_advance(token_id):
|
||||
stepped, complete, reset = pending_constraint.update(token_id)
|
||||
|
||||
if not stepped:
|
||||
raise Exception(
|
||||
"constraint.update(token_id) is not yielding incremental progress, "
|
||||
"even though constraint.does_advance(token_id) is true."
|
||||
)
|
||||
|
||||
if complete:
|
||||
self.complete_constraints.append(pending_constraint)
|
||||
self.inprogress_constraint = None
|
||||
|
||||
if not complete and stepped:
|
||||
self.inprogress_constraint = pending_constraint
|
||||
|
||||
if complete or stepped:
|
||||
# If we made any progress at all, then it's at least not a "pending constraint".
|
||||
|
||||
self.pending_constraints = (
|
||||
self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]
|
||||
)
|
||||
|
||||
if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
|
||||
# If there's no longer any pending after this and no inprogress either, then we must be
|
||||
# complete.
|
||||
|
||||
self.completed = True
|
||||
|
||||
break # prevent accidentally stepping through multiple constraints with just one token.
|
||||
|
||||
return complete, stepped
|
||||
|
||||
def copy(self, stateful=True):
|
||||
new_state = ConstraintListState(self.constraints) # we actually never though self.constraints objects
|
||||
# throughout this process. So it's at initialization state.
|
||||
|
||||
if stateful:
|
||||
new_state.complete_constraints = [
|
||||
constraint.copy(stateful=True) for constraint in self.complete_constraints
|
||||
]
|
||||
if self.inprogress_constraint is not None:
|
||||
new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
|
||||
new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
|
||||
|
||||
return new_state
|
||||
@@ -16,11 +16,13 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .generation_beam_constraints import Constraint, ConstraintListState
|
||||
|
||||
|
||||
PROCESS_INPUTS_DOCSTRING = r"""
|
||||
@@ -336,12 +338,462 @@ class BeamSearchScorer(BeamScorer):
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
"sequences": decoded,
|
||||
"sequence_scores": best_scores,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
r"""
|
||||
[`BeamScorer`] implementing constrained beam search decoding.
|
||||
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
|
||||
max_length (`int`):
|
||||
The maximum length of the sequence to be generated.
|
||||
num_beams (`int`):
|
||||
Number of beams for beam search.
|
||||
constraints (`List[Constraint]`):
|
||||
A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
|
||||
output. For more information, the documentation of [`Constraint`] should be read.
|
||||
device (`torch.device`):
|
||||
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
|
||||
allocated.
|
||||
length_penalty (`float`, *optional*, defaults to 1.0):
|
||||
Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
|
||||
model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
|
||||
sequences.
|
||||
do_early_stopping (`bool`, *optional*, defaults to `False`):
|
||||
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
|
||||
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
|
||||
The number of beam hypotheses that shall be returned upon calling
|
||||
[`~transformer.BeamSearchScorer.finalize`].
|
||||
num_beam_groups (`int`):
|
||||
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
||||
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_beams: int,
|
||||
constraints: List[Constraint],
|
||||
device: torch.device,
|
||||
length_penalty: Optional[float] = 1.0,
|
||||
do_early_stopping: Optional[bool] = False,
|
||||
num_beam_hyps_to_keep: Optional[int] = 1,
|
||||
num_beam_groups: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_beams = num_beams
|
||||
self.device = device
|
||||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
self.num_beam_groups = num_beam_groups
|
||||
self.group_size = self.num_beams // self.num_beam_groups
|
||||
self.constraints = constraints
|
||||
|
||||
self._is_init = False
|
||||
self._beam_hyps = [
|
||||
BeamHypotheses(
|
||||
num_beams=self.num_beams,
|
||||
length_penalty=self.length_penalty,
|
||||
early_stopping=self.do_early_stopping,
|
||||
)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
|
||||
|
||||
if not isinstance(num_beams, int) or num_beams <= 1:
|
||||
raise ValueError(
|
||||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
|
||||
)
|
||||
|
||||
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
||||
raise ValueError(
|
||||
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
|
||||
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
)
|
||||
|
||||
if "max_length" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
|
||||
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
|
||||
", or `group_beam_search(...)`."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return self._done.all()
|
||||
|
||||
def make_constraint_states(self, n):
|
||||
return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
|
||||
|
||||
def check_completes_constraints(self, sequence):
|
||||
new_state = self.make_constraint_states(1)[0]
|
||||
new_state = new_state.reset(sequence)
|
||||
return new_state.completed
|
||||
|
||||
def process(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
next_scores: torch.FloatTensor,
|
||||
next_tokens: torch.LongTensor,
|
||||
next_indices: torch.LongTensor,
|
||||
scores_for_all_vocab: torch.FloatTensor,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
||||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
|
||||
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
|
||||
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
||||
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
|
||||
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
||||
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
||||
scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
||||
The scores of all tokens in the vocabulary for each of the beam hypotheses.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
|
||||
Return:
|
||||
`UserDict`: A dictionary composed of the fields as defined above:
|
||||
|
||||
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
|
||||
all
|
||||
non-finished beams.
|
||||
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
|
||||
added
|
||||
to the non-finished beam_hypotheses.
|
||||
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
|
||||
indicating to which beam the next tokens shall be added.
|
||||
"""
|
||||
|
||||
cur_len = input_ids.shape[-1]
|
||||
batch_size = len(self._beam_hyps)
|
||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||
if self.num_beam_groups > 1:
|
||||
raise ValueError(
|
||||
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
||||
f"size of {self.group_size} is expected by the beam scorer."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
|
||||
f"{self.group_size} is expected by the beam scorer."
|
||||
)
|
||||
|
||||
device = input_ids.device
|
||||
|
||||
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||
|
||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_idx]:
|
||||
if self.num_beams < len(beam_hyp):
|
||||
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
||||
if eos_token_id is None or pad_token_id is None:
|
||||
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
||||
# pad the batch
|
||||
next_beam_scores[batch_idx, :] = 0
|
||||
next_beam_tokens[batch_idx, :] = pad_token_id
|
||||
next_beam_indices[batch_idx, :] = 0
|
||||
continue
|
||||
|
||||
# next tokens for this sentence.
|
||||
beam_idx = 0
|
||||
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
||||
):
|
||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
|
||||
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
|
||||
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx])
|
||||
if completes_constraint:
|
||||
beam_hyp.add(
|
||||
input_ids[batch_beam_idx].clone(),
|
||||
next_score.item(),
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
next_beam_scores[batch_idx, beam_idx] = next_score
|
||||
next_beam_tokens[batch_idx, beam_idx] = next_token
|
||||
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
||||
beam_idx += 1
|
||||
|
||||
# once the beam for next step is full, don't add more tokens to it.
|
||||
if beam_idx == self.group_size:
|
||||
break
|
||||
|
||||
new_scores, new_tokens, new_indices = self.step_sentence_constraint(
|
||||
batch_idx,
|
||||
input_ids,
|
||||
scores_for_all_vocab,
|
||||
next_beam_scores[batch_idx],
|
||||
next_beam_tokens[batch_idx],
|
||||
next_beam_indices[batch_idx],
|
||||
)
|
||||
|
||||
next_beam_scores[batch_idx] = new_scores
|
||||
next_beam_tokens[batch_idx] = new_tokens
|
||||
next_beam_indices[batch_idx] = new_indices
|
||||
|
||||
if beam_idx < self.group_size:
|
||||
raise ValueError(
|
||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
)
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len
|
||||
)
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
"next_beam_scores": next_beam_scores.view(-1),
|
||||
"next_beam_tokens": next_beam_tokens.view(-1),
|
||||
"next_beam_indices": next_beam_indices.view(-1),
|
||||
}
|
||||
)
|
||||
|
||||
def step_sentence_constraint(
|
||||
self,
|
||||
batch_idx: int,
|
||||
input_ids: torch.LongTensor,
|
||||
vocab_scores: torch.FloatTensor,
|
||||
sent_beam_scores: torch.FloatTensor,
|
||||
sent_beam_tokens: torch.LongTensor,
|
||||
sent_beam_indices: torch.LongTensor,
|
||||
push_progress: bool = False,
|
||||
):
|
||||
# sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
|
||||
# (candidate next tokens)
|
||||
|
||||
# 1. Adding "advance_tokens"
|
||||
# using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
|
||||
# advance us in fulfilling the constraints.
|
||||
|
||||
# 2. Selecting best candidates such that we end up with highest probable candidates
|
||||
# that fulfill our constraints.
|
||||
|
||||
orig_len = sent_beam_indices.size(0)
|
||||
device = sent_beam_indices.device
|
||||
|
||||
# initialize states
|
||||
topk_contraint_states = self.make_constraint_states(orig_len)
|
||||
advance_constraint_states = self.make_constraint_states(orig_len)
|
||||
|
||||
sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
|
||||
this_batch_input_ids = input_ids[sidx:eidx]
|
||||
this_batch_token_scores = vocab_scores[sidx:eidx]
|
||||
full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
|
||||
|
||||
# need to make new hypothesis that advance the constraints
|
||||
track_new = {"new_seqs": [], "new_states": [], "new_indices": [], "new_tokens": [], "new_scores": []}
|
||||
for seq_idx, pre_seq in enumerate(this_batch_input_ids):
|
||||
# pre_seq = ith sequence generated before this step.
|
||||
|
||||
# input_ids -> (topk) generic beam search best model next tokens
|
||||
# -> (advance) constraints forcing the next token
|
||||
# either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
|
||||
# hypotheses.
|
||||
|
||||
topk_state = topk_contraint_states[seq_idx]
|
||||
topk_state.reset(full_hypotheses[seq_idx])
|
||||
|
||||
advance_state = advance_constraint_states[seq_idx]
|
||||
advance_state.reset(pre_seq)
|
||||
|
||||
if not advance_state.completed:
|
||||
advance_tokens = advance_state.advance()
|
||||
for advance_token in advance_tokens.to(device):
|
||||
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
|
||||
new_state = advance_state.copy(stateful=True)
|
||||
new_state.add(advance_token)
|
||||
|
||||
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
|
||||
if advance_seq not in track_new["new_seqs"]:
|
||||
# prevent duplicates, which are basically bound to happen in this process.
|
||||
track_new["new_seqs"].append(advance_seq)
|
||||
track_new["new_indices"].append(seq_idx)
|
||||
track_new["new_tokens"].append(advance_token)
|
||||
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
|
||||
track_new["new_states"].append(new_state)
|
||||
elif push_progress:
|
||||
# Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
|
||||
# actually fulfill our constraints. For example, let constraints == ["loves pies"] and
|
||||
|
||||
# pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
|
||||
|
||||
# Without this step, if `sent_beam_indices` is something like [1,1], then
|
||||
# 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
|
||||
# 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
|
||||
# the else part of `if constraints_completed[seq_idx]`)
|
||||
# 3. it ends up simply getting removed from consideration.
|
||||
|
||||
# #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
|
||||
# especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
|
||||
# search times, since completed sequences keep getting removed after all this effort for constrained
|
||||
# generation.
|
||||
|
||||
# Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
|
||||
# appending the next likely token in the vocabulary and adding it to the list of hypotheses.
|
||||
|
||||
new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
|
||||
advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
|
||||
|
||||
advance_state = advance_constraint_states[seq_idx]
|
||||
|
||||
advance_state.reset(advance_seq)
|
||||
advance_seq = advance_seq.cpu().tolist()
|
||||
if advance_seq not in track_new["new_seqs"]:
|
||||
# but still don't want to have duplicates
|
||||
track_new["new_seqs"].append(advance_seq)
|
||||
track_new["new_indices"].append(seq_idx)
|
||||
track_new["new_tokens"].append(new_token)
|
||||
track_new["new_scores"].append(new_score)
|
||||
track_new["new_states"].append(advance_state)
|
||||
|
||||
if len(track_new["new_indices"]) > 0:
|
||||
new_indices = torch.tensor(track_new["new_indices"]).to(device)
|
||||
new_tokens = torch.stack(track_new["new_tokens"]).to(device)
|
||||
new_scores = torch.stack(track_new["new_scores"]).to(device)
|
||||
|
||||
all_states = topk_contraint_states + track_new["new_states"]
|
||||
all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
|
||||
all_scores = torch.cat((sent_beam_scores, new_scores), -1)
|
||||
all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
|
||||
|
||||
zipped = all_banks * 100 + all_scores
|
||||
indices = zipped.sort(descending=True).indices
|
||||
sorted_banks = all_banks[indices]
|
||||
|
||||
# Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
|
||||
|
||||
counter = -1
|
||||
cur_bank = sorted_banks[0]
|
||||
increments = []
|
||||
for bank in sorted_banks:
|
||||
if bank == cur_bank:
|
||||
counter += 1
|
||||
else:
|
||||
counter = 0
|
||||
cur_bank = bank
|
||||
increments.append(counter)
|
||||
rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
|
||||
|
||||
indices = indices[rearrangers][:orig_len]
|
||||
|
||||
sent_beam_scores = all_scores[indices]
|
||||
sent_beam_tokens = all_tokens[indices]
|
||||
sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
|
||||
|
||||
return sent_beam_scores, sent_beam_tokens, sent_beam_indices
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
final_beam_scores: torch.FloatTensor,
|
||||
final_beam_tokens: torch.LongTensor,
|
||||
final_beam_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> Tuple[torch.LongTensor]:
|
||||
batch_size = len(self._beam_hyps)
|
||||
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_idx]:
|
||||
continue
|
||||
|
||||
# all open beam hypotheses are added to the beam hypothesis
|
||||
# beam hypothesis class automatically keeps the best beams
|
||||
|
||||
ids_collect = []
|
||||
for beam_id in range(self.num_beams):
|
||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
|
||||
completes_constraint = self.check_completes_constraints(final_tokens)
|
||||
if completes_constraint:
|
||||
beam_hyp.add(final_tokens, final_score)
|
||||
ids_collect.append(beam_id)
|
||||
|
||||
# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
|
||||
# generation. In these cases we simply return the highest scoring outputs.
|
||||
if len(ids_collect) < self.num_beam_hyps_to_keep:
|
||||
for beam_id in range(self.num_beams):
|
||||
if beam_id not in ids_collect:
|
||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
beam_hyp.add(final_tokens, final_score)
|
||||
if len(ids_collect) >= self.num_beam_hyps_to_keep:
|
||||
break
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||
best = []
|
||||
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
||||
|
||||
# retrieve best hypotheses
|
||||
for i, beam_hyp in enumerate(self._beam_hyps):
|
||||
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
|
||||
for j in range(self.num_beam_hyps_to_keep):
|
||||
best_hyp_tuple = sorted_hyps.pop()
|
||||
best_score = best_hyp_tuple[0]
|
||||
best_hyp = best_hyp_tuple[1]
|
||||
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
||||
|
||||
# append to lists
|
||||
best.append(best_hyp)
|
||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||
|
||||
# prepare for adding eos
|
||||
sent_lengths_max = sent_lengths.max().item() + 1
|
||||
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < sent_max_len:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
return UserDict(
|
||||
{
|
||||
"sequences": decoded,
|
||||
|
||||
@@ -24,7 +24,8 @@ import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from .file_utils import ModelOutput
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_beam_constraints import Constraint
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
@@ -839,6 +840,7 @@ class GenerationMixin:
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
|
||||
constraints: Optional[List[Constraint]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
@@ -860,7 +862,6 @@ class GenerationMixin:
|
||||
post](https://huggingface.co/blog/how-to-generate).
|
||||
|
||||
Parameters:
|
||||
|
||||
inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
|
||||
feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
|
||||
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
|
||||
@@ -945,6 +946,9 @@ class GenerationMixin:
|
||||
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
||||
model's config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
model's config an error is thrown. This feature is intended for advanced users.
|
||||
constraints (`List[Constraint]`, *optional*):
|
||||
Custom constraints that can be added to the generation to ensure that the output will contain the use
|
||||
of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more details.
|
||||
@@ -966,7 +970,6 @@ class GenerationMixin:
|
||||
crash. Note that using `remove_invalid_values` can slow down generation.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
|
||||
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
|
||||
@@ -1140,11 +1143,14 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
# 6. determine generation mode
|
||||
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
|
||||
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
|
||||
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
|
||||
is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
|
||||
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)
|
||||
is_constraint_gen_mode = constraints is not None
|
||||
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
|
||||
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
|
||||
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
|
||||
is_beam_sample_gen_mode = (
|
||||
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
|
||||
)
|
||||
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None
|
||||
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||
@@ -1339,6 +1345,50 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_constraint_gen_mode:
|
||||
if num_return_sequences > num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
if num_beams <= 1:
|
||||
raise ValueError("`num_beams` needs to be greater than 1 for constrained genertation.")
|
||||
|
||||
if do_sample:
|
||||
raise ValueError("`do_sample` needs to be false for constrained generation.")
|
||||
|
||||
if num_beam_groups is not None and num_beam_groups > 1:
|
||||
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
|
||||
|
||||
# 10. prepare beam search scorer
|
||||
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
||||
constraints=constraints,
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
do_early_stopping=early_stopping,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
)
|
||||
# 11. interleave input_ids with `num_beams` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
|
||||
)
|
||||
# 12. run beam search
|
||||
return self.constrained_beam_search(
|
||||
input_ids,
|
||||
constrained_beam_scorer=constrained_beam_scorer,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def greedy_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@@ -2800,6 +2850,318 @@ class GenerationMixin:
|
||||
else:
|
||||
return sequence_outputs["sequences"]
|
||||
|
||||
def constrained_beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
constrained_beam_scorer: ConstrainedBeamSearchScorer,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[BeamSearchOutput, torch.LongTensor]:
|
||||
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head using beam search decoding.
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
|
||||
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
||||
sorted during generation, while satisfying a list of positive constraints. For more information, the
|
||||
documentation of [`ConstrainedBeamSearchScorer`] should be read.
|
||||
logits_processor (`LogitsProcessorList`, *optional*):
|
||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
||||
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
||||
used to tell if the generation loop should stop.
|
||||
logits_warper (`LogitsProcessorList`, *optional*):
|
||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
||||
to warp the prediction score distribution of the language modeling head applied before multinomial
|
||||
sampling at each generation step.
|
||||
max_length (`int`, *optional*, defaults to 20):
|
||||
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
|
||||
tokens. The maximum length of the sequence to be generated.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more details.
|
||||
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more details.
|
||||
output_scores (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
||||
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
|
||||
Return:
|
||||
[`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
|
||||
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
||||
[`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
||||
`return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... AutoTokenizer,
|
||||
... AutoModelForSeq2SeqLM,
|
||||
... LogitsProcessorList,
|
||||
... MinLengthLogitsProcessor,
|
||||
... ConstrainedBeamSearchScorer,
|
||||
... PhrasalConstraint,
|
||||
... )
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
>>> encoder_input_str = "translate English to German: How old are you?"
|
||||
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
|
||||
>>> # lets run beam search using 3 beams
|
||||
>>> num_beams = 3
|
||||
>>> # define decoder start token ids
|
||||
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
|
||||
>>> input_ids = input_ids * model.config.decoder_start_token_id
|
||||
|
||||
>>> # add encoder_outputs to model keyword arguments
|
||||
>>> model_kwargs = {
|
||||
... "encoder_outputs": model.get_encoder()(
|
||||
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
|
||||
... )
|
||||
... }
|
||||
|
||||
>>> constraint_str = "sind"
|
||||
>>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token
|
||||
>>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
|
||||
|
||||
|
||||
>>> # instantiate beam scorer
|
||||
>>> beam_scorer = ConstrainedBeamSearchScorer(
|
||||
... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
|
||||
... )
|
||||
|
||||
>>> # instantiate logits processors
|
||||
>>> logits_processor = LogitsProcessorList(
|
||||
... [
|
||||
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
|
||||
... ]
|
||||
... )
|
||||
|
||||
>>> outputs = model.constrained_beam_search(
|
||||
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
|
||||
... )
|
||||
|
||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
# => ['Wie alter sind Sie?']
|
||||
```"""
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
if len(stopping_criteria) == 0:
|
||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict_in_generate = (
|
||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
if return_dict_in_generate and self.config.is_encoder_decoder:
|
||||
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||
encoder_hidden_states = (
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
batch_size = len(constrained_beam_scorer._beam_hyps)
|
||||
num_beams = constrained_beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while True:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
outputs = self(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
|
||||
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
||||
next_token_scores = nn.functional.log_softmax(
|
||||
next_token_logits, dim=-1
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||
|
||||
scores_for_all_vocab = next_token_scores_processed.clone()
|
||||
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
(outputs.decoder_hidden_states,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.hidden_states,)
|
||||
)
|
||||
|
||||
# reshape for beam search
|
||||
vocab_size = next_token_scores.shape[-1]
|
||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||
|
||||
next_token_scores, next_tokens = torch.topk(
|
||||
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = (next_tokens / vocab_size).long()
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
beam_outputs = constrained_beam_scorer.process(
|
||||
input_ids,
|
||||
next_token_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
scores_for_all_vocab,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
beam_scores = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = constrained_beam_scorer.finalize(
|
||||
input_ids,
|
||||
beam_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSearchEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
cross_attentions=cross_attentions,
|
||||
decoder_hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
return BeamSearchDecoderOnlyOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
return sequence_outputs["sequences"]
|
||||
|
||||
|
||||
def top_k_top_p_filtering(
|
||||
logits: torch.FloatTensor,
|
||||
|
||||
@@ -80,6 +80,27 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Constraint(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ConstraintListState(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PhrasalConstraint(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BeamScorer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -94,6 +115,13 @@ class BeamSearchScorer(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -25,7 +25,8 @@ from .test_modeling_common import floats_tensor, ids_tensor
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer
|
||||
from transformers.generation_beam_constraints import PhrasalConstraint
|
||||
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
|
||||
|
||||
class BeamSearchTester:
|
||||
@@ -232,6 +233,270 @@ class BeamSearchTester:
|
||||
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
|
||||
|
||||
|
||||
class ConstrainedBeamSearchTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
constraints=None,
|
||||
batch_size=3,
|
||||
sequence_length=10,
|
||||
vocab_size=99,
|
||||
pad_token_id=0,
|
||||
max_length=20,
|
||||
num_beams=4,
|
||||
length_penalty=2.0,
|
||||
do_early_stopping=True,
|
||||
num_beam_hyps_to_keep=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.vocab_size = vocab_size
|
||||
self.pad_token_id = pad_token_id
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
|
||||
if constraints is None:
|
||||
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
self.constraints = constraints
|
||||
# cannot be randomely generated
|
||||
self.eos_token_id = vocab_size + 1
|
||||
|
||||
def prepare_constrained_beam_scorer(self, **kwargs):
|
||||
return ConstrainedBeamSearchScorer(
|
||||
constraints=kwargs.get("constraints", self.constraints),
|
||||
batch_size=kwargs.get("batch_size", self.batch_size),
|
||||
num_beams=kwargs.get("num_beams", self.num_beams),
|
||||
device=torch_device,
|
||||
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
||||
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
|
||||
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
|
||||
)
|
||||
|
||||
def prepare_inputs(self):
|
||||
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
|
||||
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
|
||||
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
|
||||
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
|
||||
scores_for_all_vocab, _ = (
|
||||
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device)
|
||||
).sort(descending=True)
|
||||
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)
|
||||
|
||||
def check_beam_hypotheses(self, input_ids, *args):
|
||||
# check that correct number of beam hypotheses is set in beam scorer
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
|
||||
beam_hyp = constrained_beam_scorer._beam_hyps[0]
|
||||
|
||||
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
|
||||
|
||||
# check correct type
|
||||
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
|
||||
|
||||
# check that num_beams is correctly set
|
||||
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
|
||||
|
||||
# check for early stopping deactivated
|
||||
for beam_idx in range(self.num_beams):
|
||||
beam_hyp.add(input_ids[beam_idx], -10.0)
|
||||
|
||||
# if early stopping True -> score does not matter
|
||||
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
|
||||
|
||||
# re-init
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
|
||||
beam_hyp = constrained_beam_scorer._beam_hyps[0]
|
||||
|
||||
# add `num_beams + 1` beams to change `worst_score`
|
||||
for beam_idx in range(self.num_beams + 1):
|
||||
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
|
||||
|
||||
# -10.0 is removed => -9.0 is worst score
|
||||
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
|
||||
|
||||
# -5.0 is better than worst score => should not be finished
|
||||
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
|
||||
|
||||
# -20.0 is worse than worst score => should be finished
|
||||
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
|
||||
|
||||
def check_constrained_beam_scorer_update(
|
||||
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
|
||||
):
|
||||
# check too many eos tokens
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
|
||||
fulfill_len = fulfilling_sequence.size(0)
|
||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[0, :] = self.eos_token_id
|
||||
|
||||
with self.parent.assertRaises(ValueError):
|
||||
constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
|
||||
# check all batches are done
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, : self.num_beams] = self.eos_token_id
|
||||
constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
# beam scorer should be done
|
||||
self.parent.assertTrue(constrained_beam_scorer.is_done)
|
||||
|
||||
# check
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, 1] = self.eos_token_id
|
||||
beam_outputs = constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
|
||||
def cut_expected_tensor(tensor):
|
||||
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
|
||||
|
||||
# check all outptus
|
||||
# cut out id of eos token and take best `num_beams` outputs
|
||||
expected_output_tokens = cut_expected_tensor(tokens)
|
||||
expected_output_scores = cut_expected_tensor(next_scores)
|
||||
|
||||
# add num_beams * batch_idx
|
||||
expected_output_indices = (
|
||||
cut_expected_tensor(next_indices)
|
||||
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
|
||||
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
|
||||
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
||||
|
||||
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
||||
)
|
||||
|
||||
def check_constrained_beam_scorer_finalize(
|
||||
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
|
||||
):
|
||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||
max_length = self.sequence_length + 1
|
||||
|
||||
# for testing finalize, we do want to have fulfilled constraints
|
||||
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
|
||||
fulfill_len = fulfilling_sequence.size(0)
|
||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
|
||||
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
|
||||
constraints = constrained_beam_scorer.constraints
|
||||
# update beams and append to input_ids
|
||||
tokens = next_tokens.clone()
|
||||
# first batch, first output has to finish with eos token id since scores are correctly sorted
|
||||
tokens[0, 0] = self.eos_token_id
|
||||
# make sure corresponding score is as good as possible to surely be picked first
|
||||
next_scores[0, 0] = 0.0
|
||||
|
||||
beam_outputs = constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# finalize
|
||||
sequence_output = constrained_beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
|
||||
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
|
||||
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
|
||||
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
|
||||
|
||||
# check sequence_scores
|
||||
self.parent.assertFalse((sequence_scores > 0).any().item())
|
||||
|
||||
# first batch has to finish with eos_token
|
||||
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
|
||||
|
||||
# other batches cannot finish with eos token
|
||||
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
|
||||
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
|
||||
|
||||
# test that the constraint is indeed fulfilled
|
||||
for output in sequences:
|
||||
for constraint in constraints:
|
||||
forced_token_ids = constraint.token_ids
|
||||
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
|
||||
|
||||
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
|
||||
|
||||
# constrained_beam_scorer.num_beam_hyps_to_keep = self.num_beams
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
|
||||
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
|
||||
sequence_output = constrained_beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
|
||||
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
|
||||
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
# set to same device. we don't care what device.
|
||||
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
|
||||
|
||||
in_order = tensor_1.size(0) <= tensor_2.size(0)
|
||||
longer = tensor_2 if in_order else tensor_1
|
||||
shorter = tensor_1 if in_order else tensor_2
|
||||
|
||||
flag = False
|
||||
chunk_size = shorter.size(0)
|
||||
for chunk_idx in range(longer.size(0) - chunk_size + 1):
|
||||
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
||||
if torch.equal(subseq, shorter):
|
||||
flag = True
|
||||
break
|
||||
|
||||
return flag
|
||||
|
||||
|
||||
@require_torch
|
||||
class BeamSearchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -248,3 +513,21 @@ class BeamSearchTest(unittest.TestCase):
|
||||
def test_beam_scorer_finalize(self):
|
||||
inputs = self.beam_search_tester.prepare_inputs()
|
||||
self.beam_search_tester.check_beam_scores_finalize(*inputs)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ConstrainedBeamSearchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self)
|
||||
|
||||
def test_constrained_beam_hypotheses(self):
|
||||
inputs = self.constrained_beam_search_tester.prepare_inputs()
|
||||
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs)
|
||||
|
||||
def test_constrained_beam_scorer_update(self):
|
||||
inputs = self.constrained_beam_search_tester.prepare_inputs()
|
||||
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs)
|
||||
|
||||
def test_constrained_beam_scorer_finalize(self):
|
||||
inputs = self.constrained_beam_search_tester.prepare_inputs()
|
||||
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs)
|
||||
|
||||
@@ -27,6 +27,8 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BartForConditionalGeneration,
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
@@ -37,7 +39,8 @@ if is_torch_available():
|
||||
VisionEncoderDecoderModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
from transformers.generation_beam_constraints import PhrasalConstraint
|
||||
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
@@ -190,6 +193,25 @@ class GenerationTesterMixin:
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
"num_beams": num_return_sequences * 4,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
}
|
||||
beam_scorer = ConstrainedBeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
constraints=constraints,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
do_early_stopping=beam_kwargs["early_stopping"],
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(
|
||||
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
||||
@@ -526,6 +548,69 @@ class GenerationTesterMixin:
|
||||
)
|
||||
return output_generate, output_group_beam_search
|
||||
|
||||
def _constrained_beam_search_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
constrained_beam_scorer,
|
||||
constraints,
|
||||
beam_kwargs,
|
||||
logits_processor,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
constraints=constraints,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
num_interleave=constrained_beam_scorer.num_beams,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_group_beam_search = model.constrained_beam_search(
|
||||
input_ids_clone,
|
||||
constrained_beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
)
|
||||
return output_generate, output_group_beam_search
|
||||
|
||||
def test_greedy_generate(self):
|
||||
# check `generate()` and `greedy_search()` are equal
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -719,6 +804,7 @@ class GenerationTesterMixin:
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
|
||||
|
||||
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
|
||||
@@ -1085,6 +1171,164 @@ class GenerationTesterMixin:
|
||||
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
|
||||
)
|
||||
|
||||
def test_constrained_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# It is important set set the eos_token_id to None to ensure that no sequences
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
max_length = 20
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.eos_token_id,
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
# check `generate()` and `constrained_beam_search()` are equal
|
||||
# Sample constraints
|
||||
if not input_ids.dtype == torch.float32:
|
||||
min_id = torch.min(input_ids) + 3
|
||||
max_id = torch.max(input_ids)
|
||||
else:
|
||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
||||
min_id = 3
|
||||
max_id = 100
|
||||
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, constraints, num_return_sequences=1
|
||||
)
|
||||
output_generate, output_beam_search = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
constrained_beam_scorer=beam_scorer,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_processor=logits_processor,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
|
||||
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
|
||||
# Sample constraints
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
num_return_sequences = 2
|
||||
max_length = 20
|
||||
|
||||
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
output_generate, output_beam_search = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
constrained_beam_scorer=beam_scorer,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_processor=logits_processor,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
# It is important set set the eos_token_id to None to ensure that no sequences
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 20
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.eos_token_id,
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
if not input_ids.dtype == torch.float32:
|
||||
min_id = torch.min(input_ids) + 3
|
||||
max_id = torch.max(input_ids)
|
||||
else:
|
||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
||||
min_id = 3
|
||||
max_id = 100
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, constraints, num_return_sequences=1
|
||||
)
|
||||
output_generate, output_beam_search = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
constrained_beam_scorer=beam_scorer,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_processor=logits_processor,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
|
||||
self.assertTrue(
|
||||
torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
|
||||
)
|
||||
self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
|
||||
self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
|
||||
|
||||
for output in (output_beam_search, output_generate):
|
||||
self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
@@ -1254,6 +1498,24 @@ class GenerationTesterMixin:
|
||||
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
# set to same device. we don't care what device.
|
||||
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
|
||||
|
||||
in_order = tensor_1.size(0) <= tensor_2.size(0)
|
||||
longer = tensor_2 if in_order else tensor_1
|
||||
shorter = tensor_1 if in_order else tensor_2
|
||||
|
||||
flag = False
|
||||
chunk_size = shorter.size(0)
|
||||
for chunk_idx in range(longer.size(0) - chunk_size + 1):
|
||||
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
||||
if torch.equal(subseq, shorter):
|
||||
flag = True
|
||||
break
|
||||
|
||||
self.assertTrue(flag)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
@@ -2047,3 +2309,83 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
transition_scores_sum = transition_scores.sum(-1)
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0]
|
||||
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0]
|
||||
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
PhrasalConstraint(force_tokens_2),
|
||||
]
|
||||
|
||||
starting_text = ["The soldiers were not prepared and"]
|
||||
|
||||
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
constraints=constraints,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
max_length=30,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do",
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_integration(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
encoder_input_str = "translate English to German: How old are you?"
|
||||
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
# lets run beam search using 5 beams
|
||||
num_beams = 5
|
||||
# define decoder start token ids
|
||||
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
|
||||
input_ids = input_ids * model.config.decoder_start_token_id
|
||||
|
||||
# add encoder_outputs to model keyword arguments
|
||||
model_kwargs = {
|
||||
"encoder_outputs": model.get_encoder()(
|
||||
encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
|
||||
)
|
||||
}
|
||||
|
||||
constraint_str = "sind"
|
||||
constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token
|
||||
constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
|
||||
|
||||
# instantiate beam scorer
|
||||
beam_scorer = ConstrainedBeamSearchScorer(
|
||||
batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
|
||||
)
|
||||
|
||||
# instantiate logits processors
|
||||
logits_processor = LogitsProcessorList(
|
||||
[
|
||||
MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
|
||||
]
|
||||
)
|
||||
|
||||
outputs = model.constrained_beam_search(
|
||||
input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
|
||||
)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
|
||||
|
||||
Reference in New Issue
Block a user