Constrained Beam Search [*With* Disjunctive Decoding] (#15761)
* added classes to get started with constrained beam search * in progress, think i can directly force tokens now but not yet with the round robin * think now i have total control, now need to code the bank selection * technically works as desired, need to optimize and fix design choices leading to undersirable outputs * complete PR #1 without disjunctive decoding * removed incorrect tests * Delete k.txt * Delete test.py * Delete test.sh * revert changes to test scripts * genutils * full implementation with testing, no disjunctive yet * shifted docs * passing all tests realistically ran locally * removing accidentally included print statements * fixed source of error in initial PR test * fixing the get_device() vs device trap * fixed documentation docstrings about constrained_beam_search * fixed tests having failing for Speech2TextModel's floating point inputs * fix cuda long tensor * added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search * deleted accidentally added test halting code with assert False * code reformat * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py * fixing based on comments on PR * took out the testing code that should but work fails without the beam search moditification ; style changes * fixing comments issues * docstrings for ConstraintListState * typo in PhrsalConstraint docstring * docstrings improvements * finished adding what is sort of an opinionated implementation of disjunctive generation, but it revealed errors in inner beam search logic during testing. * fixed bug found in constrained beam search that used beam_idx that were not global across all the batches * disjunctive constraint working 100% correctly * passing all tests * Accidentally included mlruns * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * complete overhaul of type complexities and other nits * strict type checks in generate() * fixing second round of feedback by narsil * fixed failing generation test because of type check overhaul * generation test fail fix * fixing test fails Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -229,6 +229,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
|||||||
|
|
||||||
[[autodoc]] PhrasalConstraint
|
[[autodoc]] PhrasalConstraint
|
||||||
|
|
||||||
|
[[autodoc]] DisjunctiveConstraint
|
||||||
|
|
||||||
[[autodoc]] ConstraintListState
|
[[autodoc]] ConstraintListState
|
||||||
|
|
||||||
## BeamSearch
|
## BeamSearch
|
||||||
|
|||||||
@@ -623,6 +623,7 @@ if is_torch_available():
|
|||||||
_import_structure["generation_beam_constraints"] = [
|
_import_structure["generation_beam_constraints"] = [
|
||||||
"Constraint",
|
"Constraint",
|
||||||
"ConstraintListState",
|
"ConstraintListState",
|
||||||
|
"DisjunctiveConstraint",
|
||||||
"PhrasalConstraint",
|
"PhrasalConstraint",
|
||||||
]
|
]
|
||||||
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
|
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
|
||||||
@@ -2857,7 +2858,12 @@ if TYPE_CHECKING:
|
|||||||
TextDataset,
|
TextDataset,
|
||||||
TextDatasetForNextSentencePrediction,
|
TextDatasetForNextSentencePrediction,
|
||||||
)
|
)
|
||||||
from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint
|
from .generation_beam_constraints import (
|
||||||
|
Constraint,
|
||||||
|
ConstraintListState,
|
||||||
|
DisjunctiveConstraint,
|
||||||
|
PhrasalConstraint,
|
||||||
|
)
|
||||||
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .generation_logits_process import (
|
from .generation_logits_process import (
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class Constraint(ABC):
|
class Constraint(ABC):
|
||||||
@@ -137,37 +135,38 @@ class PhrasalConstraint(Constraint):
|
|||||||
The id of the token that must be generated by the output.
|
The id of the token that must be generated by the output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, token_ids: Union[List[int], torch.LongTensor]):
|
def __init__(self, token_ids: List[int]):
|
||||||
super(Constraint, self).__init__()
|
super(Constraint, self).__init__()
|
||||||
|
|
||||||
is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int)
|
if not isinstance(token_ids, list) or len(token_ids) == 0:
|
||||||
is_tensor = isinstance(token_ids, torch.Tensor)
|
raise ValueError(f"`token_ids` has to be a non-emtpy list, but is {token_ids}.")
|
||||||
is_int_tensor = (
|
if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
|
||||||
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1
|
raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
|
||||||
)
|
|
||||||
not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0
|
|
||||||
if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive:
|
|
||||||
raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}")
|
|
||||||
|
|
||||||
if not is_tensor:
|
|
||||||
token_ids = torch.tensor(token_ids)
|
|
||||||
|
|
||||||
self.token_ids = token_ids
|
self.token_ids = token_ids
|
||||||
|
|
||||||
self.seqlen = self.token_ids.size(0)
|
self.seqlen = len(self.token_ids)
|
||||||
self.fulfilled_idx = -1 # the index of the currently fulfilled step
|
self.fulfilled_idx = -1 # the index of the currently fulfilled step
|
||||||
self.completed = False
|
self.completed = False
|
||||||
|
|
||||||
def advance(self):
|
def advance(self):
|
||||||
|
if self.completed:
|
||||||
|
return None
|
||||||
return self.token_ids[self.fulfilled_idx + 1]
|
return self.token_ids[self.fulfilled_idx + 1]
|
||||||
|
|
||||||
def does_advance(self, token_id: int):
|
def does_advance(self, token_id: int):
|
||||||
|
if not isinstance(token_id, int):
|
||||||
|
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
||||||
|
|
||||||
if self.completed:
|
if self.completed:
|
||||||
return False
|
return False
|
||||||
# move to cpu to guarantee no device issues.
|
|
||||||
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu()
|
return token_id == self.token_ids[self.fulfilled_idx + 1]
|
||||||
|
|
||||||
def update(self, token_id: int):
|
def update(self, token_id: int):
|
||||||
|
if not isinstance(token_id, int):
|
||||||
|
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
||||||
|
|
||||||
stepped = False
|
stepped = False
|
||||||
completed = False
|
completed = False
|
||||||
reset = False
|
reset = False
|
||||||
@@ -202,6 +201,151 @@ class PhrasalConstraint(Constraint):
|
|||||||
return new_constraint
|
return new_constraint
|
||||||
|
|
||||||
|
|
||||||
|
class DisjunctiveTrie:
|
||||||
|
def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
|
||||||
|
r"""
|
||||||
|
A helper class that builds a trie with the words represented in `nested_token_ids`.
|
||||||
|
"""
|
||||||
|
self.max_height = max([len(one) for one in nested_token_ids])
|
||||||
|
|
||||||
|
root = dict()
|
||||||
|
for token_ids in nested_token_ids:
|
||||||
|
level = root
|
||||||
|
for tidx, token_id in enumerate(token_ids):
|
||||||
|
if token_id not in level:
|
||||||
|
level[token_id] = dict()
|
||||||
|
|
||||||
|
level = level[token_id]
|
||||||
|
|
||||||
|
if no_subsets and self.has_subsets(root, nested_token_ids):
|
||||||
|
raise ValueError(
|
||||||
|
f"Each list in `nested_token_ids` can't be a complete subset of another list, but is {nested_token_ids}."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.trie = root
|
||||||
|
|
||||||
|
def next_tokens(self, current_seq):
|
||||||
|
"""
|
||||||
|
The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
|
||||||
|
"""
|
||||||
|
start = self.trie
|
||||||
|
|
||||||
|
for current_token in current_seq:
|
||||||
|
start = start[current_token]
|
||||||
|
|
||||||
|
next_tokens = list(start.keys())
|
||||||
|
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def reached_leaf(self, current_seq):
|
||||||
|
next_tokens = self.next_tokens(current_seq)
|
||||||
|
|
||||||
|
return len(next_tokens) == 0
|
||||||
|
|
||||||
|
def count_leaves(self, root):
|
||||||
|
next_nodes = list(root.values())
|
||||||
|
if len(next_nodes) == 0:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return sum([self.count_leaves(nn) for nn in next_nodes])
|
||||||
|
|
||||||
|
def has_subsets(self, trie, nested_token_ids):
|
||||||
|
"""
|
||||||
|
Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
|
||||||
|
"""
|
||||||
|
leaf_count = self.count_leaves(trie)
|
||||||
|
return len(nested_token_ids) != leaf_count
|
||||||
|
|
||||||
|
|
||||||
|
class DisjunctiveConstraint(Constraint):
|
||||||
|
r"""
|
||||||
|
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nested_token_ids (`List[List[int]]`): a list of words, where each word is a list of ids. This constraint
|
||||||
|
is fulfilled by generating just one from the list of words.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, nested_token_ids: List[List[int]]):
|
||||||
|
super(Constraint, self).__init__()
|
||||||
|
|
||||||
|
if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
|
||||||
|
raise ValueError(f"`nested_token_ids` has to be a non-emtpy list, but is {nested_token_ids}.")
|
||||||
|
if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
|
||||||
|
raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
|
||||||
|
if any(
|
||||||
|
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
||||||
|
for token_ids in nested_token_ids
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.trie = DisjunctiveTrie(nested_token_ids)
|
||||||
|
self.token_ids = nested_token_ids
|
||||||
|
|
||||||
|
self.seqlen = self.trie.max_height
|
||||||
|
self.current_seq = []
|
||||||
|
self.completed = False
|
||||||
|
|
||||||
|
def advance(self):
|
||||||
|
token_list = self.trie.next_tokens(self.current_seq)
|
||||||
|
|
||||||
|
if len(token_list) == 0:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return token_list
|
||||||
|
|
||||||
|
def does_advance(self, token_id: int):
|
||||||
|
if not isinstance(token_id, int):
|
||||||
|
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
||||||
|
|
||||||
|
next_tokens = self.trie.next_tokens(self.current_seq)
|
||||||
|
|
||||||
|
return token_id in next_tokens
|
||||||
|
|
||||||
|
def update(self, token_id: int):
|
||||||
|
if not isinstance(token_id, int):
|
||||||
|
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
||||||
|
|
||||||
|
stepped = False
|
||||||
|
completed = False
|
||||||
|
reset = False
|
||||||
|
|
||||||
|
if self.does_advance(token_id):
|
||||||
|
self.current_seq.append(token_id)
|
||||||
|
stepped = True
|
||||||
|
else:
|
||||||
|
reset = True
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
completed = self.trie.reached_leaf(self.current_seq)
|
||||||
|
self.completed = completed
|
||||||
|
|
||||||
|
return stepped, completed, reset
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.completed = False
|
||||||
|
self.current_seq = []
|
||||||
|
|
||||||
|
def remaining(self):
|
||||||
|
if self.completed:
|
||||||
|
# since this can be completed without reaching max height
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return self.seqlen - len(self.current_seq)
|
||||||
|
|
||||||
|
def copy(self, stateful=False):
|
||||||
|
new_constraint = DisjunctiveConstraint(self.token_ids)
|
||||||
|
|
||||||
|
if stateful:
|
||||||
|
new_constraint.seq_len = self.seqlen
|
||||||
|
new_constraint.current_seq = self.current_seq
|
||||||
|
new_constraint.completed = self.completed
|
||||||
|
|
||||||
|
return new_constraint
|
||||||
|
|
||||||
|
|
||||||
class ConstraintListState:
|
class ConstraintListState:
|
||||||
r"""
|
r"""
|
||||||
A class for beam scorers to track its progress through a list of constraints.
|
A class for beam scorers to track its progress through a list of constraints.
|
||||||
@@ -215,7 +359,7 @@ class ConstraintListState:
|
|||||||
self.constraints = constraints
|
self.constraints = constraints
|
||||||
|
|
||||||
# max # of steps required to fulfill a given constraint
|
# max # of steps required to fulfill a given constraint
|
||||||
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)])
|
self.max_seqlen = max([c.seqlen for c in constraints])
|
||||||
self.n_constraints = len(constraints)
|
self.n_constraints = len(constraints)
|
||||||
self.completed = False
|
self.completed = False
|
||||||
|
|
||||||
@@ -249,26 +393,33 @@ class ConstraintListState:
|
|||||||
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
|
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
|
||||||
that's the only one we'll return.
|
that's the only one we'll return.
|
||||||
"""
|
"""
|
||||||
if self.inprogress_constraint is None:
|
|
||||||
token_list = []
|
token_list = []
|
||||||
|
if self.inprogress_constraint is None:
|
||||||
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
|
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
|
||||||
advance = constraint.advance()
|
advance = constraint.advance()
|
||||||
|
if isinstance(advance, int):
|
||||||
token_list.append(advance)
|
token_list.append(advance)
|
||||||
|
elif isinstance(advance, list):
|
||||||
|
token_list.extend(advance)
|
||||||
else:
|
else:
|
||||||
token_list = [self.inprogress_constraint.advance()]
|
advance = self.inprogress_constraint.advance()
|
||||||
|
if isinstance(advance, int):
|
||||||
|
token_list.append(advance)
|
||||||
|
elif isinstance(advance, list):
|
||||||
|
token_list.extend(advance)
|
||||||
|
|
||||||
if len(token_list) == 0:
|
if len(token_list) == 0:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return torch.stack(token_list)
|
return token_list
|
||||||
|
|
||||||
def reset(self, token_ids: Optional[torch.LongTensor]):
|
def reset(self, token_ids: Optional[List[int]]):
|
||||||
"""
|
"""
|
||||||
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
|
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
|
||||||
"""
|
"""
|
||||||
self.init_state()
|
self.init_state()
|
||||||
|
|
||||||
if token_ids is not None and token_ids.size(0) > 0:
|
if token_ids is not None:
|
||||||
for token in token_ids:
|
for token in token_ids:
|
||||||
# completes or steps **one** constraint
|
# completes or steps **one** constraint
|
||||||
complete, stepped = self.add(token)
|
complete, stepped = self.add(token)
|
||||||
@@ -277,9 +428,10 @@ class ConstraintListState:
|
|||||||
if self.completed:
|
if self.completed:
|
||||||
break
|
break
|
||||||
|
|
||||||
return self
|
def add(self, token_id: int):
|
||||||
|
if not isinstance(token_id, int):
|
||||||
|
raise ValueError(f"`token_id` should be an `int`, but is `{token_id}`.")
|
||||||
|
|
||||||
def add(self, token_id: Union[int, torch.LongTensor]):
|
|
||||||
complete, stepped = False, False
|
complete, stepped = False, False
|
||||||
|
|
||||||
if self.completed:
|
if self.completed:
|
||||||
@@ -324,8 +476,8 @@ class ConstraintListState:
|
|||||||
|
|
||||||
if not stepped:
|
if not stepped:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"constraint.update(token_id) is not yielding incremental progress, "
|
"`constraint.update(token_id)` is not yielding incremental progress, "
|
||||||
"even though constraint.does_advance(token_id) is true."
|
"even though `constraint.does_advance(token_id)` is true."
|
||||||
)
|
)
|
||||||
|
|
||||||
if complete:
|
if complete:
|
||||||
|
|||||||
@@ -443,7 +443,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
|
|
||||||
def check_completes_constraints(self, sequence):
|
def check_completes_constraints(self, sequence):
|
||||||
new_state = self.make_constraint_states(1)[0]
|
new_state = self.make_constraint_states(1)[0]
|
||||||
new_state = new_state.reset(sequence)
|
new_state.reset(sequence)
|
||||||
return new_state.completed
|
return new_state.completed
|
||||||
|
|
||||||
def process(
|
def process(
|
||||||
@@ -484,6 +484,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
|
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
|
||||||
all
|
all
|
||||||
non-finished beams.
|
non-finished beams.
|
||||||
|
|
||||||
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
|
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
|
||||||
added
|
added
|
||||||
to the non-finished beam_hypotheses.
|
to the non-finished beam_hypotheses.
|
||||||
@@ -537,7 +538,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
if is_beam_token_worse_than_top_num_beams:
|
if is_beam_token_worse_than_top_num_beams:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx])
|
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
|
||||||
if completes_constraint:
|
if completes_constraint:
|
||||||
beam_hyp.add(
|
beam_hyp.add(
|
||||||
input_ids[batch_beam_idx].clone(),
|
input_ids[batch_beam_idx].clone(),
|
||||||
@@ -628,23 +629,23 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
# hypotheses.
|
# hypotheses.
|
||||||
|
|
||||||
topk_state = topk_contraint_states[seq_idx]
|
topk_state = topk_contraint_states[seq_idx]
|
||||||
topk_state.reset(full_hypotheses[seq_idx])
|
topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
|
||||||
|
|
||||||
advance_state = advance_constraint_states[seq_idx]
|
advance_state = advance_constraint_states[seq_idx]
|
||||||
advance_state.reset(pre_seq)
|
advance_state.reset(pre_seq.cpu().tolist())
|
||||||
|
|
||||||
if not advance_state.completed:
|
if not advance_state.completed:
|
||||||
advance_tokens = advance_state.advance()
|
advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
|
||||||
for advance_token in advance_tokens.to(device):
|
for advance_token in advance_tokens:
|
||||||
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
|
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
|
||||||
new_state = advance_state.copy(stateful=True)
|
new_state = advance_state.copy(stateful=True)
|
||||||
new_state.add(advance_token)
|
new_state.add(advance_token.cpu().tolist())
|
||||||
|
|
||||||
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
|
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
|
||||||
if advance_seq not in track_new["new_seqs"]:
|
if advance_seq not in track_new["new_seqs"]:
|
||||||
# prevent duplicates, which are basically bound to happen in this process.
|
# prevent duplicates, which are basically bound to happen in this process.
|
||||||
track_new["new_seqs"].append(advance_seq)
|
track_new["new_seqs"].append(advance_seq)
|
||||||
track_new["new_indices"].append(seq_idx)
|
track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
|
||||||
track_new["new_tokens"].append(advance_token)
|
track_new["new_tokens"].append(advance_token)
|
||||||
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
|
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
|
||||||
track_new["new_states"].append(new_state)
|
track_new["new_states"].append(new_state)
|
||||||
@@ -673,8 +674,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
|
|
||||||
advance_state = advance_constraint_states[seq_idx]
|
advance_state = advance_constraint_states[seq_idx]
|
||||||
|
|
||||||
advance_state.reset(advance_seq)
|
|
||||||
advance_seq = advance_seq.cpu().tolist()
|
advance_seq = advance_seq.cpu().tolist()
|
||||||
|
|
||||||
|
advance_state.reset(advance_seq)
|
||||||
if advance_seq not in track_new["new_seqs"]:
|
if advance_seq not in track_new["new_seqs"]:
|
||||||
# but still don't want to have duplicates
|
# but still don't want to have duplicates
|
||||||
track_new["new_seqs"].append(advance_seq)
|
track_new["new_seqs"].append(advance_seq)
|
||||||
@@ -745,7 +747,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
final_score = final_beam_scores[batch_beam_idx].item()
|
final_score = final_beam_scores[batch_beam_idx].item()
|
||||||
final_tokens = input_ids[batch_beam_idx]
|
final_tokens = input_ids[batch_beam_idx]
|
||||||
|
|
||||||
completes_constraint = self.check_completes_constraints(final_tokens)
|
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
|
||||||
if completes_constraint:
|
if completes_constraint:
|
||||||
beam_hyp.add(final_tokens, final_score)
|
beam_hyp.add(final_tokens, final_score)
|
||||||
ids_collect.append(beam_id)
|
ids_collect.append(beam_id)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import torch.distributed as dist
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .file_utils import ModelOutput
|
from .file_utils import ModelOutput
|
||||||
from .generation_beam_constraints import Constraint
|
from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
|
||||||
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .generation_logits_process import (
|
from .generation_logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
@@ -818,6 +818,7 @@ class GenerationMixin:
|
|||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
bad_words_ids: Optional[Iterable[int]] = None,
|
bad_words_ids: Optional[Iterable[int]] = None,
|
||||||
|
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
|
||||||
bos_token_id: Optional[int] = None,
|
bos_token_id: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
@@ -904,6 +905,11 @@ class GenerationMixin:
|
|||||||
List of token ids that are not allowed to be generated. In order to get the token ids of the words that
|
List of token ids that are not allowed to be generated. In order to get the token ids of the words that
|
||||||
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
|
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
|
||||||
add_special_tokens=False).input_ids`.
|
add_special_tokens=False).input_ids`.
|
||||||
|
force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
|
||||||
|
List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
|
||||||
|
list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
|
||||||
|
this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
|
||||||
|
where one can allow different forms of each word.
|
||||||
num_return_sequences(`int`, *optional*, defaults to 1):
|
num_return_sequences(`int`, *optional*, defaults to 1):
|
||||||
The number of independently computed returned sequences for each element in the batch.
|
The number of independently computed returned sequences for each element in the batch.
|
||||||
max_time(`float`, *optional*, defaults to None):
|
max_time(`float`, *optional*, defaults to None):
|
||||||
@@ -1038,10 +1044,18 @@ class GenerationMixin:
|
|||||||
>>> bad_words_ids = tokenizer(
|
>>> bad_words_ids = tokenizer(
|
||||||
... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
|
... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
|
||||||
>>> ).input_ids
|
>>> ).input_ids
|
||||||
|
>>> # get tokens of words that we want generated
|
||||||
|
>>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids
|
||||||
>>> # encode input context
|
>>> # encode input context
|
||||||
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
|
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
|
||||||
>>> # generate sequences without allowing bad_words to be generated
|
>>> # generate sequences without allowing bad_words to be generated
|
||||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
|
>>> outputs = model.generate(
|
||||||
|
... input_ids=input_ids,
|
||||||
|
... max_length=20,
|
||||||
|
... do_sample=True,
|
||||||
|
... bad_words_ids=bad_words_ids,
|
||||||
|
... force_words_ids=force_words_ids,
|
||||||
|
... )
|
||||||
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
|
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||||
```"""
|
```"""
|
||||||
# 1. Set generation parameters if not already defined
|
# 1. Set generation parameters if not already defined
|
||||||
@@ -1138,14 +1152,20 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 6. determine generation mode
|
# 6. determine generation mode
|
||||||
is_constraint_gen_mode = constraints is not None
|
is_constraint_gen_mode = constraints is not None or force_words_ids is not None
|
||||||
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
|
is_greedy_gen_mode = (
|
||||||
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
|
(num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
|
||||||
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
|
|
||||||
is_beam_sample_gen_mode = (
|
|
||||||
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
|
|
||||||
)
|
)
|
||||||
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None
|
is_sample_gen_mode = (
|
||||||
|
(num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
|
||||||
|
)
|
||||||
|
is_beam_gen_mode = (
|
||||||
|
(num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
|
||||||
|
)
|
||||||
|
is_beam_sample_gen_mode = (
|
||||||
|
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
|
||||||
|
)
|
||||||
|
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
|
||||||
|
|
||||||
if num_beam_groups > num_beams:
|
if num_beam_groups > num_beams:
|
||||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||||
@@ -1356,9 +1376,46 @@ class GenerationMixin:
|
|||||||
if num_beam_groups is not None and num_beam_groups > 1:
|
if num_beam_groups is not None and num_beam_groups > 1:
|
||||||
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
|
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
|
||||||
|
|
||||||
|
final_constraints = []
|
||||||
|
if constraints is not None:
|
||||||
|
final_constraints = constraints
|
||||||
|
|
||||||
|
if force_words_ids is not None:
|
||||||
|
|
||||||
|
def typeerror():
|
||||||
|
raise ValueError(
|
||||||
|
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
|
||||||
|
f"of positive integers, but is {force_words_ids}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
|
||||||
|
typeerror()
|
||||||
|
|
||||||
|
for word_ids in force_words_ids:
|
||||||
|
if isinstance(word_ids[0], list):
|
||||||
|
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
||||||
|
typeerror()
|
||||||
|
if any(not isinstance(token_ids, list) for token_ids in word_ids):
|
||||||
|
typeerror()
|
||||||
|
if any(
|
||||||
|
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
||||||
|
for token_ids in word_ids
|
||||||
|
):
|
||||||
|
typeerror()
|
||||||
|
|
||||||
|
constraint = DisjunctiveConstraint(word_ids)
|
||||||
|
else:
|
||||||
|
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
||||||
|
typeerror()
|
||||||
|
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
|
||||||
|
typeerror()
|
||||||
|
|
||||||
|
constraint = PhrasalConstraint(word_ids)
|
||||||
|
final_constraints.append(constraint)
|
||||||
|
|
||||||
# 10. prepare beam search scorer
|
# 10. prepare beam search scorer
|
||||||
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
||||||
constraints=constraints,
|
constraints=final_constraints,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
|||||||
@@ -94,6 +94,13 @@ class ConstraintListState(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DisjunctiveConstraint(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class PhrasalConstraint(metaclass=DummyObject):
|
class PhrasalConstraint(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
115
tests/generation/test_generation_beam_constraints.py
Normal file
115
tests/generation/test_generation_beam_constraints.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 The HuggingFace Team Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a clone of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import is_torch_available
|
||||||
|
from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.generation_beam_constraints import DisjunctiveConstraint
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class ConstraintTest(unittest.TestCase):
|
||||||
|
def test_input_types(self):
|
||||||
|
# For consistency across different places the DisjunctiveConstraint is called,
|
||||||
|
# dc.token_ids is a list of integers. It is also initialized only by integers.
|
||||||
|
|
||||||
|
cset = [[1, 2, 4], [1, 2, 3, 4]]
|
||||||
|
dc = DisjunctiveConstraint(cset)
|
||||||
|
self.assertTrue(isinstance(dc.token_ids, list))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
DisjunctiveConstraint(torch.LongTensor([[1, 2, 4], [1, 2, 3]]))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])])
|
||||||
|
|
||||||
|
def test_check_illegal_input(self):
|
||||||
|
# We can't have constraints that are complete subsets of another. This leads to a preverse
|
||||||
|
# interpretation of "constraint fulfillment": does generating [1,2,3] fulfill the constraint?
|
||||||
|
# It would mean that it generated [1,2] which fulfills it, but it's in the middle of potentially
|
||||||
|
# fulfilling [1,2,3,4]. If we believe that [1,2,3] does fulfill the constraint, then the algorithm
|
||||||
|
# will necessarily never reach [1,2,3,4], giving users a false sense of control (better to just not allow it).
|
||||||
|
cset = [[1, 2], [1, 2, 3, 4]]
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
DisjunctiveConstraint(cset) # fails here
|
||||||
|
|
||||||
|
def test_example_progression(self):
|
||||||
|
cset = [[1, 2, 3], [1, 2, 4]]
|
||||||
|
|
||||||
|
dc = DisjunctiveConstraint(cset)
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(1)
|
||||||
|
desired = stepped is True and completed is False and reset is False
|
||||||
|
self.assertTrue(desired)
|
||||||
|
self.assertTrue(not dc.completed)
|
||||||
|
self.assertTrue(dc.current_seq == [1])
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(2)
|
||||||
|
desired = stepped is True and completed is False and reset is False
|
||||||
|
self.assertTrue(desired)
|
||||||
|
self.assertTrue(not dc.completed)
|
||||||
|
self.assertTrue(dc.current_seq == [1, 2])
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(3)
|
||||||
|
desired = stepped is True and completed is True and reset is False
|
||||||
|
self.assertTrue(desired)
|
||||||
|
self.assertTrue(dc.completed) # Completed!
|
||||||
|
self.assertTrue(dc.current_seq == [1, 2, 3])
|
||||||
|
|
||||||
|
def test_example_progression_unequal_three_mid_and_reset(self):
|
||||||
|
cset = [[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]]
|
||||||
|
|
||||||
|
dc = DisjunctiveConstraint(cset)
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(1)
|
||||||
|
self.assertTrue(not dc.completed)
|
||||||
|
self.assertTrue(dc.current_seq == [1])
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(2)
|
||||||
|
self.assertTrue(not dc.completed)
|
||||||
|
self.assertTrue(dc.current_seq == [1, 2])
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(4)
|
||||||
|
self.assertTrue(not dc.completed)
|
||||||
|
self.assertTrue(dc.current_seq == [1, 2, 4])
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(5)
|
||||||
|
self.assertTrue(dc.completed) # Completed!
|
||||||
|
self.assertTrue(dc.current_seq == [1, 2, 4, 5])
|
||||||
|
|
||||||
|
dc.reset()
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(1)
|
||||||
|
self.assertTrue(not dc.completed)
|
||||||
|
self.assertTrue(dc.remaining() == 3)
|
||||||
|
self.assertTrue(dc.current_seq == [1])
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(2)
|
||||||
|
self.assertTrue(not dc.completed)
|
||||||
|
self.assertTrue(dc.remaining() == 2)
|
||||||
|
self.assertTrue(dc.current_seq == [1, 2])
|
||||||
|
|
||||||
|
stepped, completed, reset = dc.update(5)
|
||||||
|
self.assertTrue(dc.completed) # Completed!
|
||||||
|
self.assertTrue(dc.remaining() == 0)
|
||||||
|
self.assertTrue(dc.current_seq == [1, 2, 5])
|
||||||
@@ -25,7 +25,7 @@ from ..test_modeling_common import floats_tensor, ids_tensor
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers.generation_beam_constraints import PhrasalConstraint
|
from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||||
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
|
|
||||||
|
|
||||||
@@ -260,10 +260,10 @@ class ConstrainedBeamSearchTester:
|
|||||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||||
|
|
||||||
if constraints is None:
|
if constraints is None:
|
||||||
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
|
force_tokens = torch.randint(10, 50, (1, 2))[0].tolist()
|
||||||
constraints = [
|
disjunctive_tokens = torch.randint(10, 50, (2, 2)).tolist()
|
||||||
PhrasalConstraint(force_tokens),
|
|
||||||
]
|
constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)]
|
||||||
self.constraints = constraints
|
self.constraints = constraints
|
||||||
# cannot be randomely generated
|
# cannot be randomely generated
|
||||||
self.eos_token_id = vocab_size + 1
|
self.eos_token_id = vocab_size + 1
|
||||||
@@ -331,7 +331,13 @@ class ConstrainedBeamSearchTester:
|
|||||||
):
|
):
|
||||||
# check too many eos tokens
|
# check too many eos tokens
|
||||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||||
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
|
stacked_token_ids = []
|
||||||
|
for constraint in self.constraints:
|
||||||
|
token_ids = constraint.token_ids
|
||||||
|
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
|
||||||
|
stacked_token_ids = stacked_token_ids + token_ids
|
||||||
|
|
||||||
|
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
|
||||||
fulfill_len = fulfilling_sequence.size(0)
|
fulfill_len = fulfilling_sequence.size(0)
|
||||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||||
|
|
||||||
@@ -398,7 +404,14 @@ class ConstrainedBeamSearchTester:
|
|||||||
max_length = self.sequence_length + 1
|
max_length = self.sequence_length + 1
|
||||||
|
|
||||||
# for testing finalize, we do want to have fulfilled constraints
|
# for testing finalize, we do want to have fulfilled constraints
|
||||||
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
|
stacked_token_ids = []
|
||||||
|
for constraint in self.constraints:
|
||||||
|
token_ids = constraint.token_ids
|
||||||
|
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
|
||||||
|
stacked_token_ids = stacked_token_ids + token_ids
|
||||||
|
|
||||||
|
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
|
||||||
|
|
||||||
fulfill_len = fulfilling_sequence.size(0)
|
fulfill_len = fulfilling_sequence.size(0)
|
||||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||||
|
|
||||||
@@ -451,9 +464,17 @@ class ConstrainedBeamSearchTester:
|
|||||||
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
|
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
|
||||||
|
|
||||||
# test that the constraint is indeed fulfilled
|
# test that the constraint is indeed fulfilled
|
||||||
for output in sequences:
|
for (output, constraint) in [(s, c) for s in sequences for c in constraints]:
|
||||||
for constraint in constraints:
|
|
||||||
forced_token_ids = constraint.token_ids
|
forced_token_ids = constraint.token_ids
|
||||||
|
if isinstance(forced_token_ids[0], list):
|
||||||
|
# disjunctive case
|
||||||
|
flag = False
|
||||||
|
for token_ids in forced_token_ids:
|
||||||
|
if self._check_sequence_inside_sequence(output, token_ids):
|
||||||
|
flag = True
|
||||||
|
break
|
||||||
|
self.parent.assertEqual(flag, True)
|
||||||
|
else:
|
||||||
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
|
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
|
||||||
|
|
||||||
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
|
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
|
||||||
@@ -479,18 +500,23 @@ class ConstrainedBeamSearchTester:
|
|||||||
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
|
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
|
||||||
|
|
||||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||||
|
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||||
# set to same device. we don't care what device.
|
# set to same device. we don't care what device.
|
||||||
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
|
|
||||||
|
|
||||||
in_order = tensor_1.size(0) <= tensor_2.size(0)
|
if not isinstance(tensor_1, list):
|
||||||
|
tensor_1 = tensor_1.cpu().tolist()
|
||||||
|
if not isinstance(tensor_2, list):
|
||||||
|
tensor_2 = tensor_2.cpu().tolist()
|
||||||
|
|
||||||
|
in_order = len(tensor_1) <= len(tensor_2)
|
||||||
longer = tensor_2 if in_order else tensor_1
|
longer = tensor_2 if in_order else tensor_1
|
||||||
shorter = tensor_1 if in_order else tensor_2
|
shorter = tensor_1 if in_order else tensor_2
|
||||||
|
|
||||||
flag = False
|
flag = False
|
||||||
chunk_size = shorter.size(0)
|
chunk_size = len(shorter)
|
||||||
for chunk_idx in range(longer.size(0) - chunk_size + 1):
|
for chunk_idx in range(len(longer) - chunk_size + 1):
|
||||||
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
||||||
if torch.equal(subseq, shorter):
|
if subseq == shorter:
|
||||||
flag = True
|
flag = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ if is_torch_available():
|
|||||||
VisionEncoderDecoderModel,
|
VisionEncoderDecoderModel,
|
||||||
top_k_top_p_filtering,
|
top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
from transformers.generation_beam_constraints import PhrasalConstraint
|
from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||||
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
|
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from transformers.generation_logits_process import (
|
from transformers.generation_logits_process import (
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
@@ -1202,7 +1202,7 @@ class GenerationTesterMixin:
|
|||||||
min_id = 3
|
min_id = 3
|
||||||
max_id = 100
|
max_id = 100
|
||||||
|
|
||||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||||
constraints = [
|
constraints = [
|
||||||
PhrasalConstraint(force_tokens),
|
PhrasalConstraint(force_tokens),
|
||||||
]
|
]
|
||||||
@@ -1227,7 +1227,7 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
|
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||||
constraints = [
|
constraints = [
|
||||||
PhrasalConstraint(force_tokens),
|
PhrasalConstraint(force_tokens),
|
||||||
]
|
]
|
||||||
@@ -1288,7 +1288,7 @@ class GenerationTesterMixin:
|
|||||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
||||||
min_id = 3
|
min_id = 3
|
||||||
max_id = 100
|
max_id = 100
|
||||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||||
constraints = [
|
constraints = [
|
||||||
PhrasalConstraint(force_tokens),
|
PhrasalConstraint(force_tokens),
|
||||||
]
|
]
|
||||||
@@ -1499,18 +1499,23 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||||
|
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||||
# set to same device. we don't care what device.
|
# set to same device. we don't care what device.
|
||||||
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
|
|
||||||
|
|
||||||
in_order = tensor_1.size(0) <= tensor_2.size(0)
|
if not isinstance(tensor_1, list):
|
||||||
|
tensor_1 = tensor_1.cpu().tolist()
|
||||||
|
if not isinstance(tensor_2, list):
|
||||||
|
tensor_2 = tensor_2.cpu().tolist()
|
||||||
|
|
||||||
|
in_order = len(tensor_1) <= len(tensor_2)
|
||||||
longer = tensor_2 if in_order else tensor_1
|
longer = tensor_2 if in_order else tensor_1
|
||||||
shorter = tensor_1 if in_order else tensor_2
|
shorter = tensor_1 if in_order else tensor_2
|
||||||
|
|
||||||
flag = False
|
flag = False
|
||||||
chunk_size = shorter.size(0)
|
chunk_size = len(shorter)
|
||||||
for chunk_idx in range(longer.size(0) - chunk_size + 1):
|
for chunk_idx in range(len(longer) - chunk_size + 1):
|
||||||
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
||||||
if torch.equal(subseq, shorter):
|
if subseq == shorter:
|
||||||
flag = True
|
flag = True
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -2315,8 +2320,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||||
|
|
||||||
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0]
|
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||||
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0]
|
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||||
|
|
||||||
constraints = [
|
constraints = [
|
||||||
PhrasalConstraint(force_tokens),
|
PhrasalConstraint(force_tokens),
|
||||||
@@ -2346,6 +2351,105 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_constrained_beam_search_mixed(self):
|
||||||
|
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||||
|
|
||||||
|
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||||
|
flexible_phrases = tokenizer(
|
||||||
|
["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False
|
||||||
|
).input_ids
|
||||||
|
|
||||||
|
constraints = [
|
||||||
|
PhrasalConstraint(force_phrase),
|
||||||
|
DisjunctiveConstraint(flexible_phrases),
|
||||||
|
]
|
||||||
|
|
||||||
|
starting_text = ["The soldiers", "The child"]
|
||||||
|
|
||||||
|
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids,
|
||||||
|
constraints=constraints,
|
||||||
|
num_beams=10,
|
||||||
|
num_return_sequences=1,
|
||||||
|
no_repeat_ngram_size=1,
|
||||||
|
# max_length=20,
|
||||||
|
remove_invalid_values=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
generated_text,
|
||||||
|
[
|
||||||
|
"The soldiers, who were all scared and screaming at each other as they tried to get out of the",
|
||||||
|
"The child was taken to a local hospital where she screamed and scared for her life, police said.",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_constrained_beam_search_mixed_mixin(self):
|
||||||
|
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||||
|
|
||||||
|
force_word = "scared"
|
||||||
|
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
||||||
|
|
||||||
|
force_words_ids = [
|
||||||
|
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
|
||||||
|
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
|
||||||
|
]
|
||||||
|
|
||||||
|
starting_text = ["The soldiers", "The child"]
|
||||||
|
|
||||||
|
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids,
|
||||||
|
force_words_ids=force_words_ids,
|
||||||
|
num_beams=10,
|
||||||
|
num_return_sequences=1,
|
||||||
|
no_repeat_ngram_size=1,
|
||||||
|
remove_invalid_values=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
generated_text,
|
||||||
|
[
|
||||||
|
"The soldiers, who were all scared and screaming at each other as they tried to get out of the",
|
||||||
|
"The child was taken to a local hospital where she screamed and scared for her life, police said.",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_constrained_beam_search_example_translation_mixin(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||||
|
|
||||||
|
encoder_input_str = "translate English to German: How old are you?"
|
||||||
|
force_words = ["sind"]
|
||||||
|
|
||||||
|
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||||
|
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids,
|
||||||
|
force_words_ids=force_words_ids,
|
||||||
|
num_beams=10,
|
||||||
|
num_return_sequences=1,
|
||||||
|
no_repeat_ngram_size=1,
|
||||||
|
remove_invalid_values=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_constrained_beam_search_example_integration(self):
|
def test_constrained_beam_search_example_integration(self):
|
||||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||||
@@ -2389,3 +2493,43 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
|
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
|
||||||
|
|
||||||
|
def test_constrained_beam_search_mixin_type_checks(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||||
|
|
||||||
|
encoder_input_str = "translate English to German: How old are you?"
|
||||||
|
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
force_words = ["sind"]
|
||||||
|
force_words_ids = tokenizer(force_words, return_tensors="pt").input_ids
|
||||||
|
model.generate(
|
||||||
|
input_ids,
|
||||||
|
force_words_ids=force_words_ids,
|
||||||
|
num_beams=10,
|
||||||
|
num_return_sequences=1,
|
||||||
|
no_repeat_ngram_size=1,
|
||||||
|
remove_invalid_values=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
force_words = ["sind"]
|
||||||
|
force_words_ids = [tokenizer(force_words, return_tensors="pt").input_ids]
|
||||||
|
model.generate(
|
||||||
|
input_ids,
|
||||||
|
force_words_ids=force_words_ids,
|
||||||
|
num_beams=10,
|
||||||
|
num_return_sequences=1,
|
||||||
|
no_repeat_ngram_size=1,
|
||||||
|
remove_invalid_values=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.generate(input_ids, force_words_ids=[])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.generate(input_ids, force_words_ids=[[-1]])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||||
|
|||||||
Reference in New Issue
Block a user