From 5c6f57ee75665499c8045a8bf7c73bf2415fba20 Mon Sep 17 00:00:00 2001 From: Chan Woo Kim Date: Sat, 5 Mar 2022 02:18:34 +0900 Subject: [PATCH] Constrained Beam Search [*With* Disjunctive Decoding] (#15761) * added classes to get started with constrained beam search * in progress, think i can directly force tokens now but not yet with the round robin * think now i have total control, now need to code the bank selection * technically works as desired, need to optimize and fix design choices leading to undersirable outputs * complete PR #1 without disjunctive decoding * removed incorrect tests * Delete k.txt * Delete test.py * Delete test.sh * revert changes to test scripts * genutils * full implementation with testing, no disjunctive yet * shifted docs * passing all tests realistically ran locally * removing accidentally included print statements * fixed source of error in initial PR test * fixing the get_device() vs device trap * fixed documentation docstrings about constrained_beam_search * fixed tests having failing for Speech2TextModel's floating point inputs * fix cuda long tensor * added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search * deleted accidentally added test halting code with assert False * code reformat * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py * fixing based on comments on PR * took out the testing code that should but work fails without the beam search moditification ; style changes * fixing comments issues * docstrings for ConstraintListState * typo in PhrsalConstraint docstring * docstrings improvements * finished adding what is sort of an opinionated implementation of disjunctive generation, but it revealed errors in inner beam search logic during testing. * fixed bug found in constrained beam search that used beam_idx that were not global across all the batches * disjunctive constraint working 100% correctly * passing all tests * Accidentally included mlruns * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen * complete overhaul of type complexities and other nits * strict type checks in generate() * fixing second round of feedback by narsil * fixed failing generation test because of type check overhaul * generation test fail fix * fixing test fails Co-authored-by: Patrick von Platen --- docs/source/internal/generation_utils.mdx | 2 + src/transformers/__init__.py | 8 +- .../generation_beam_constraints.py | 210 +++++++++++++++--- src/transformers/generation_beam_search.py | 22 +- src/transformers/generation_utils.py | 77 ++++++- src/transformers/utils/dummy_pt_objects.py | 7 + .../test_generation_beam_constraints.py | 115 ++++++++++ .../generation/test_generation_beam_search.py | 56 +++-- tests/generation/test_generation_utils.py | 166 +++++++++++++- 9 files changed, 587 insertions(+), 76 deletions(-) create mode 100644 tests/generation/test_generation_beam_constraints.py diff --git a/docs/source/internal/generation_utils.mdx b/docs/source/internal/generation_utils.mdx index 089dcf3b9c..c3e5f1936b 100644 --- a/docs/source/internal/generation_utils.mdx +++ b/docs/source/internal/generation_utils.mdx @@ -229,6 +229,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] PhrasalConstraint +[[autodoc]] DisjunctiveConstraint + [[autodoc]] ConstraintListState ## BeamSearch diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f7f3295a8d..69f21f0120 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -623,6 +623,7 @@ if is_torch_available(): _import_structure["generation_beam_constraints"] = [ "Constraint", "ConstraintListState", + "DisjunctiveConstraint", "PhrasalConstraint", ] _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"] @@ -2857,7 +2858,12 @@ if TYPE_CHECKING: TextDataset, TextDatasetForNextSentencePrediction, ) - from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint + from .generation_beam_constraints import ( + Constraint, + ConstraintListState, + DisjunctiveConstraint, + PhrasalConstraint, + ) from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_logits_process import ( ForcedBOSTokenLogitsProcessor, diff --git a/src/transformers/generation_beam_constraints.py b/src/transformers/generation_beam_constraints.py index 6410d06928..d50796bf82 100644 --- a/src/transformers/generation_beam_constraints.py +++ b/src/transformers/generation_beam_constraints.py @@ -1,7 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Union - -import torch +from typing import List, Optional class Constraint(ABC): @@ -137,37 +135,38 @@ class PhrasalConstraint(Constraint): The id of the token that must be generated by the output. """ - def __init__(self, token_ids: Union[List[int], torch.LongTensor]): + def __init__(self, token_ids: List[int]): super(Constraint, self).__init__() - is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int) - is_tensor = isinstance(token_ids, torch.Tensor) - is_int_tensor = ( - is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1 - ) - not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0 - if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive: - raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}") - - if not is_tensor: - token_ids = torch.tensor(token_ids) + if not isinstance(token_ids, list) or len(token_ids) == 0: + raise ValueError(f"`token_ids` has to be a non-emtpy list, but is {token_ids}.") + if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids): + raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.") self.token_ids = token_ids - self.seqlen = self.token_ids.size(0) + self.seqlen = len(self.token_ids) self.fulfilled_idx = -1 # the index of the currently fulfilled step self.completed = False def advance(self): + if self.completed: + return None return self.token_ids[self.fulfilled_idx + 1] def does_advance(self, token_id: int): + if not isinstance(token_id, int): + raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}") + if self.completed: return False - # move to cpu to guarantee no device issues. - return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu() + + return token_id == self.token_ids[self.fulfilled_idx + 1] def update(self, token_id: int): + if not isinstance(token_id, int): + raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}") + stepped = False completed = False reset = False @@ -202,6 +201,151 @@ class PhrasalConstraint(Constraint): return new_constraint +class DisjunctiveTrie: + def __init__(self, nested_token_ids: List[List[int]], no_subsets=True): + r""" + A helper class that builds a trie with the words represented in `nested_token_ids`. + """ + self.max_height = max([len(one) for one in nested_token_ids]) + + root = dict() + for token_ids in nested_token_ids: + level = root + for tidx, token_id in enumerate(token_ids): + if token_id not in level: + level[token_id] = dict() + + level = level[token_id] + + if no_subsets and self.has_subsets(root, nested_token_ids): + raise ValueError( + f"Each list in `nested_token_ids` can't be a complete subset of another list, but is {nested_token_ids}." + ) + + self.trie = root + + def next_tokens(self, current_seq): + """ + The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`. + """ + start = self.trie + + for current_token in current_seq: + start = start[current_token] + + next_tokens = list(start.keys()) + + return next_tokens + + def reached_leaf(self, current_seq): + next_tokens = self.next_tokens(current_seq) + + return len(next_tokens) == 0 + + def count_leaves(self, root): + next_nodes = list(root.values()) + if len(next_nodes) == 0: + return 1 + else: + return sum([self.count_leaves(nn) for nn in next_nodes]) + + def has_subsets(self, trie, nested_token_ids): + """ + Returns whether # of leaves == # of words. Otherwise some word is a subset of another. + """ + leaf_count = self.count_leaves(trie) + return len(nested_token_ids) != leaf_count + + +class DisjunctiveConstraint(Constraint): + r""" + A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints. + + Args: + nested_token_ids (`List[List[int]]`): a list of words, where each word is a list of ids. This constraint + is fulfilled by generating just one from the list of words. + """ + + def __init__(self, nested_token_ids: List[List[int]]): + super(Constraint, self).__init__() + + if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0: + raise ValueError(f"`nested_token_ids` has to be a non-emtpy list, but is {nested_token_ids}.") + if any(not isinstance(token_ids, list) for token_ids in nested_token_ids): + raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.") + if any( + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + for token_ids in nested_token_ids + ): + raise ValueError( + f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}." + ) + + self.trie = DisjunctiveTrie(nested_token_ids) + self.token_ids = nested_token_ids + + self.seqlen = self.trie.max_height + self.current_seq = [] + self.completed = False + + def advance(self): + token_list = self.trie.next_tokens(self.current_seq) + + if len(token_list) == 0: + return None + else: + return token_list + + def does_advance(self, token_id: int): + if not isinstance(token_id, int): + raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}") + + next_tokens = self.trie.next_tokens(self.current_seq) + + return token_id in next_tokens + + def update(self, token_id: int): + if not isinstance(token_id, int): + raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}") + + stepped = False + completed = False + reset = False + + if self.does_advance(token_id): + self.current_seq.append(token_id) + stepped = True + else: + reset = True + self.reset() + + completed = self.trie.reached_leaf(self.current_seq) + self.completed = completed + + return stepped, completed, reset + + def reset(self): + self.completed = False + self.current_seq = [] + + def remaining(self): + if self.completed: + # since this can be completed without reaching max height + return 0 + else: + return self.seqlen - len(self.current_seq) + + def copy(self, stateful=False): + new_constraint = DisjunctiveConstraint(self.token_ids) + + if stateful: + new_constraint.seq_len = self.seqlen + new_constraint.current_seq = self.current_seq + new_constraint.completed = self.completed + + return new_constraint + + class ConstraintListState: r""" A class for beam scorers to track its progress through a list of constraints. @@ -215,7 +359,7 @@ class ConstraintListState: self.constraints = constraints # max # of steps required to fulfill a given constraint - self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)]) + self.max_seqlen = max([c.seqlen for c in constraints]) self.n_constraints = len(constraints) self.completed = False @@ -249,26 +393,33 @@ class ConstraintListState: Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint, that's the only one we'll return. """ + token_list = [] if self.inprogress_constraint is None: - token_list = [] for constraint in self.pending_constraints: # "pending" == "unfulfilled yet" advance = constraint.advance() - token_list.append(advance) + if isinstance(advance, int): + token_list.append(advance) + elif isinstance(advance, list): + token_list.extend(advance) else: - token_list = [self.inprogress_constraint.advance()] + advance = self.inprogress_constraint.advance() + if isinstance(advance, int): + token_list.append(advance) + elif isinstance(advance, list): + token_list.extend(advance) if len(token_list) == 0: return None else: - return torch.stack(token_list) + return token_list - def reset(self, token_ids: Optional[torch.LongTensor]): + def reset(self, token_ids: Optional[List[int]]): """ token_ids: the tokens generated thus far to reset the state of the progress through constraints. """ self.init_state() - if token_ids is not None and token_ids.size(0) > 0: + if token_ids is not None: for token in token_ids: # completes or steps **one** constraint complete, stepped = self.add(token) @@ -277,9 +428,10 @@ class ConstraintListState: if self.completed: break - return self + def add(self, token_id: int): + if not isinstance(token_id, int): + raise ValueError(f"`token_id` should be an `int`, but is `{token_id}`.") - def add(self, token_id: Union[int, torch.LongTensor]): complete, stepped = False, False if self.completed: @@ -324,8 +476,8 @@ class ConstraintListState: if not stepped: raise Exception( - "constraint.update(token_id) is not yielding incremental progress, " - "even though constraint.does_advance(token_id) is true." + "`constraint.update(token_id)` is not yielding incremental progress, " + "even though `constraint.does_advance(token_id)` is true." ) if complete: diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 81dc0c5a55..8fd3f94f35 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -443,7 +443,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): def check_completes_constraints(self, sequence): new_state = self.make_constraint_states(1)[0] - new_state = new_state.reset(sequence) + new_state.reset(sequence) return new_state.completed def process( @@ -484,6 +484,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all non-finished beams. + - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added to the non-finished beam_hypotheses. @@ -537,7 +538,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): if is_beam_token_worse_than_top_num_beams: continue - completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx]) + completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist()) if completes_constraint: beam_hyp.add( input_ids[batch_beam_idx].clone(), @@ -628,23 +629,23 @@ class ConstrainedBeamSearchScorer(BeamScorer): # hypotheses. topk_state = topk_contraint_states[seq_idx] - topk_state.reset(full_hypotheses[seq_idx]) + topk_state.reset(full_hypotheses[seq_idx].cpu().tolist()) advance_state = advance_constraint_states[seq_idx] - advance_state.reset(pre_seq) + advance_state.reset(pre_seq.cpu().tolist()) if not advance_state.completed: - advance_tokens = advance_state.advance() - for advance_token in advance_tokens.to(device): + advance_tokens = torch.LongTensor(advance_state.advance()).to(device) + for advance_token in advance_tokens: # since adding each `advance_token` leads to a different hypothesis, create new state instance. new_state = advance_state.copy(stateful=True) - new_state.add(advance_token) + new_state.add(advance_token.cpu().tolist()) advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist() if advance_seq not in track_new["new_seqs"]: # prevent duplicates, which are basically bound to happen in this process. track_new["new_seqs"].append(advance_seq) - track_new["new_indices"].append(seq_idx) + track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches track_new["new_tokens"].append(advance_token) track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token)) track_new["new_states"].append(new_state) @@ -673,8 +674,9 @@ class ConstrainedBeamSearchScorer(BeamScorer): advance_state = advance_constraint_states[seq_idx] - advance_state.reset(advance_seq) advance_seq = advance_seq.cpu().tolist() + + advance_state.reset(advance_seq) if advance_seq not in track_new["new_seqs"]: # but still don't want to have duplicates track_new["new_seqs"].append(advance_seq) @@ -745,7 +747,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): final_score = final_beam_scores[batch_beam_idx].item() final_tokens = input_ids[batch_beam_idx] - completes_constraint = self.check_completes_constraints(final_tokens) + completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist()) if completes_constraint: beam_hyp.add(final_tokens, final_score) ids_collect.append(beam_id) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 379d87c484..d9a901d201 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -24,7 +24,7 @@ import torch.distributed as dist from torch import nn from .file_utils import ModelOutput -from .generation_beam_constraints import Constraint +from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, @@ -818,6 +818,7 @@ class GenerationMixin: typical_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, + force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, @@ -904,6 +905,11 @@ class GenerationMixin: List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`. + force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): + List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple + list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, + this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), + where one can allow different forms of each word. num_return_sequences(`int`, *optional*, defaults to 1): The number of independently computed returned sequences for each element in the batch. max_time(`float`, *optional*, defaults to None): @@ -1038,10 +1044,18 @@ class GenerationMixin: >>> bad_words_ids = tokenizer( ... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False >>> ).input_ids + >>> # get tokens of words that we want generated + >>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids >>> # encode input context >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids >>> # generate sequences without allowing bad_words to be generated - >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids) + >>> outputs = model.generate( + ... input_ids=input_ids, + ... max_length=20, + ... do_sample=True, + ... bad_words_ids=bad_words_ids, + ... force_words_ids=force_words_ids, + ... ) >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) ```""" # 1. Set generation parameters if not already defined @@ -1138,14 +1152,20 @@ class GenerationMixin: ) # 6. determine generation mode - is_constraint_gen_mode = constraints is not None - is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None - is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None - is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None - is_beam_sample_gen_mode = ( - (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None + is_constraint_gen_mode = constraints is not None or force_words_ids is not None + is_greedy_gen_mode = ( + (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode ) - is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None + is_sample_gen_mode = ( + (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + ) + is_beam_gen_mode = ( + (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode + ) + is_beam_sample_gen_mode = ( + (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + ) + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1356,9 +1376,46 @@ class GenerationMixin: if num_beam_groups is not None and num_beam_groups > 1: raise ValueError("`num_beam_groups` not supported yet for constrained generation.") + final_constraints = [] + if constraints is not None: + final_constraints = constraints + + if force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {force_words_ids}." + ) + + if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: + typeerror() + + for word_ids in force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any(not isinstance(token_ids, list) for token_ids in word_ids): + typeerror() + if any( + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + # 10. prepare beam search scorer constrained_beam_scorer = ConstrainedBeamSearchScorer( - constraints=constraints, + constraints=final_constraints, batch_size=batch_size, num_beams=num_beams, device=self.device, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e222a5c15d..2f4886dd4b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -94,6 +94,13 @@ class ConstraintListState(metaclass=DummyObject): requires_backends(self, ["torch"]) +class DisjunctiveConstraint(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PhrasalConstraint(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_generation_beam_constraints.py b/tests/generation/test_generation_beam_constraints.py new file mode 100644 index 0000000000..311cdc1429 --- /dev/null +++ b/tests/generation/test_generation_beam_constraints.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch + + +if is_torch_available(): + import torch + + from transformers.generation_beam_constraints import DisjunctiveConstraint + + +@require_torch +class ConstraintTest(unittest.TestCase): + def test_input_types(self): + # For consistency across different places the DisjunctiveConstraint is called, + # dc.token_ids is a list of integers. It is also initialized only by integers. + + cset = [[1, 2, 4], [1, 2, 3, 4]] + dc = DisjunctiveConstraint(cset) + self.assertTrue(isinstance(dc.token_ids, list)) + + with self.assertRaises(ValueError): + DisjunctiveConstraint(torch.LongTensor([[1, 2, 4], [1, 2, 3]])) + + with self.assertRaises(ValueError): + DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])]) + + def test_check_illegal_input(self): + # We can't have constraints that are complete subsets of another. This leads to a preverse + # interpretation of "constraint fulfillment": does generating [1,2,3] fulfill the constraint? + # It would mean that it generated [1,2] which fulfills it, but it's in the middle of potentially + # fulfilling [1,2,3,4]. If we believe that [1,2,3] does fulfill the constraint, then the algorithm + # will necessarily never reach [1,2,3,4], giving users a false sense of control (better to just not allow it). + cset = [[1, 2], [1, 2, 3, 4]] + + with self.assertRaises(ValueError): + DisjunctiveConstraint(cset) # fails here + + def test_example_progression(self): + cset = [[1, 2, 3], [1, 2, 4]] + + dc = DisjunctiveConstraint(cset) + + stepped, completed, reset = dc.update(1) + desired = stepped is True and completed is False and reset is False + self.assertTrue(desired) + self.assertTrue(not dc.completed) + self.assertTrue(dc.current_seq == [1]) + + stepped, completed, reset = dc.update(2) + desired = stepped is True and completed is False and reset is False + self.assertTrue(desired) + self.assertTrue(not dc.completed) + self.assertTrue(dc.current_seq == [1, 2]) + + stepped, completed, reset = dc.update(3) + desired = stepped is True and completed is True and reset is False + self.assertTrue(desired) + self.assertTrue(dc.completed) # Completed! + self.assertTrue(dc.current_seq == [1, 2, 3]) + + def test_example_progression_unequal_three_mid_and_reset(self): + cset = [[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]] + + dc = DisjunctiveConstraint(cset) + + stepped, completed, reset = dc.update(1) + self.assertTrue(not dc.completed) + self.assertTrue(dc.current_seq == [1]) + + stepped, completed, reset = dc.update(2) + self.assertTrue(not dc.completed) + self.assertTrue(dc.current_seq == [1, 2]) + + stepped, completed, reset = dc.update(4) + self.assertTrue(not dc.completed) + self.assertTrue(dc.current_seq == [1, 2, 4]) + + stepped, completed, reset = dc.update(5) + self.assertTrue(dc.completed) # Completed! + self.assertTrue(dc.current_seq == [1, 2, 4, 5]) + + dc.reset() + + stepped, completed, reset = dc.update(1) + self.assertTrue(not dc.completed) + self.assertTrue(dc.remaining() == 3) + self.assertTrue(dc.current_seq == [1]) + + stepped, completed, reset = dc.update(2) + self.assertTrue(not dc.completed) + self.assertTrue(dc.remaining() == 2) + self.assertTrue(dc.current_seq == [1, 2]) + + stepped, completed, reset = dc.update(5) + self.assertTrue(dc.completed) # Completed! + self.assertTrue(dc.remaining() == 0) + self.assertTrue(dc.current_seq == [1, 2, 5]) diff --git a/tests/generation/test_generation_beam_search.py b/tests/generation/test_generation_beam_search.py index b50be51e1b..3971dcc79c 100644 --- a/tests/generation/test_generation_beam_search.py +++ b/tests/generation/test_generation_beam_search.py @@ -25,7 +25,7 @@ from ..test_modeling_common import floats_tensor, ids_tensor if is_torch_available(): import torch - from transformers.generation_beam_constraints import PhrasalConstraint + from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer @@ -260,10 +260,10 @@ class ConstrainedBeamSearchTester: self.num_beam_hyps_to_keep = num_beam_hyps_to_keep if constraints is None: - force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0] - constraints = [ - PhrasalConstraint(force_tokens), - ] + force_tokens = torch.randint(10, 50, (1, 2))[0].tolist() + disjunctive_tokens = torch.randint(10, 50, (2, 2)).tolist() + + constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)] self.constraints = constraints # cannot be randomely generated self.eos_token_id = vocab_size + 1 @@ -331,7 +331,13 @@ class ConstrainedBeamSearchTester: ): # check too many eos tokens constrained_beam_scorer = self.prepare_constrained_beam_scorer() - fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten() + stacked_token_ids = [] + for constraint in self.constraints: + token_ids = constraint.token_ids + token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids + stacked_token_ids = stacked_token_ids + token_ids + + fulfilling_sequence = torch.LongTensor(stacked_token_ids) fulfill_len = fulfilling_sequence.size(0) input_ids[:, :fulfill_len] = fulfilling_sequence @@ -398,7 +404,14 @@ class ConstrainedBeamSearchTester: max_length = self.sequence_length + 1 # for testing finalize, we do want to have fulfilled constraints - fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten() + stacked_token_ids = [] + for constraint in self.constraints: + token_ids = constraint.token_ids + token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids + stacked_token_ids = stacked_token_ids + token_ids + + fulfilling_sequence = torch.LongTensor(stacked_token_ids) + fulfill_len = fulfilling_sequence.size(0) input_ids[:, :fulfill_len] = fulfilling_sequence @@ -451,9 +464,17 @@ class ConstrainedBeamSearchTester: self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id) # test that the constraint is indeed fulfilled - for output in sequences: - for constraint in constraints: - forced_token_ids = constraint.token_ids + for (output, constraint) in [(s, c) for s in sequences for c in constraints]: + forced_token_ids = constraint.token_ids + if isinstance(forced_token_ids[0], list): + # disjunctive case + flag = False + for token_ids in forced_token_ids: + if self._check_sequence_inside_sequence(output, token_ids): + flag = True + break + self.parent.assertEqual(flag, True) + else: self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True) # now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned @@ -479,18 +500,23 @@ class ConstrainedBeamSearchTester: self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size]) def _check_sequence_inside_sequence(self, tensor_1, tensor_2): + # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. # set to same device. we don't care what device. - tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu() - in_order = tensor_1.size(0) <= tensor_2.size(0) + if not isinstance(tensor_1, list): + tensor_1 = tensor_1.cpu().tolist() + if not isinstance(tensor_2, list): + tensor_2 = tensor_2.cpu().tolist() + + in_order = len(tensor_1) <= len(tensor_2) longer = tensor_2 if in_order else tensor_1 shorter = tensor_1 if in_order else tensor_2 flag = False - chunk_size = shorter.size(0) - for chunk_idx in range(longer.size(0) - chunk_size + 1): + chunk_size = len(shorter) + for chunk_idx in range(len(longer) - chunk_size + 1): subseq = longer[chunk_idx : chunk_idx + chunk_size] - if torch.equal(subseq, shorter): + if subseq == shorter: flag = True break diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index dd99b9ff2b..9057691a20 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -39,7 +39,7 @@ if is_torch_available(): VisionEncoderDecoderModel, top_k_top_p_filtering, ) - from transformers.generation_beam_constraints import PhrasalConstraint + from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation_logits_process import ( ForcedBOSTokenLogitsProcessor, @@ -1202,7 +1202,7 @@ class GenerationTesterMixin: min_id = 3 max_id = 100 - force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] + force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] @@ -1227,7 +1227,7 @@ class GenerationTesterMixin: # check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences` # Sample constraints - force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] + force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] @@ -1288,7 +1288,7 @@ class GenerationTesterMixin: # otherwise this throws an error for Speech2TextModel since its inputs are floating points min_id = 3 max_id = 100 - force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] + force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] @@ -1499,18 +1499,23 @@ class GenerationTesterMixin: ) def _check_sequence_inside_sequence(self, tensor_1, tensor_2): + # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. # set to same device. we don't care what device. - tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu() - in_order = tensor_1.size(0) <= tensor_2.size(0) + if not isinstance(tensor_1, list): + tensor_1 = tensor_1.cpu().tolist() + if not isinstance(tensor_2, list): + tensor_2 = tensor_2.cpu().tolist() + + in_order = len(tensor_1) <= len(tensor_2) longer = tensor_2 if in_order else tensor_1 shorter = tensor_1 if in_order else tensor_2 flag = False - chunk_size = shorter.size(0) - for chunk_idx in range(longer.size(0) - chunk_size + 1): + chunk_size = len(shorter) + for chunk_idx in range(len(longer) - chunk_size + 1): subseq = longer[chunk_idx : chunk_idx + chunk_size] - if torch.equal(subseq, shorter): + if subseq == shorter: flag = True break @@ -2315,8 +2320,8 @@ class GenerationIntegrationTests(unittest.TestCase): model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device) tokenizer = GPT2Tokenizer.from_pretrained("../gpt2") - force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0] - force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0] + force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids + force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids constraints = [ PhrasalConstraint(force_tokens), @@ -2346,6 +2351,105 @@ class GenerationIntegrationTests(unittest.TestCase): ], ) + @slow + def test_constrained_beam_search_mixed(self): + model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("../gpt2") + + force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids + flexible_phrases = tokenizer( + ["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False + ).input_ids + + constraints = [ + PhrasalConstraint(force_phrase), + DisjunctiveConstraint(flexible_phrases), + ] + + starting_text = ["The soldiers", "The child"] + + input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device) + + outputs = model.generate( + input_ids, + constraints=constraints, + num_beams=10, + num_return_sequences=1, + no_repeat_ngram_size=1, + # max_length=20, + remove_invalid_values=True, + ) + + generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "The soldiers, who were all scared and screaming at each other as they tried to get out of the", + "The child was taken to a local hospital where she screamed and scared for her life, police said.", + ], + ) + + @slow + def test_constrained_beam_search_mixed_mixin(self): + model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("../gpt2") + + force_word = "scared" + force_flexible = ["scream", "screams", "screaming", "screamed"] + + force_words_ids = [ + tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids, + tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids, + ] + + starting_text = ["The soldiers", "The child"] + + input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device) + + outputs = model.generate( + input_ids, + force_words_ids=force_words_ids, + num_beams=10, + num_return_sequences=1, + no_repeat_ngram_size=1, + remove_invalid_values=True, + ) + + generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "The soldiers, who were all scared and screaming at each other as they tried to get out of the", + "The child was taken to a local hospital where she screamed and scared for her life, police said.", + ], + ) + + @slow + def test_constrained_beam_search_example_translation_mixin(self): + tokenizer = AutoTokenizer.from_pretrained("t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + encoder_input_str = "translate English to German: How old are you?" + force_words = ["sind"] + + input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids + + outputs = model.generate( + input_ids, + force_words_ids=force_words_ids, + num_beams=10, + num_return_sequences=1, + no_repeat_ngram_size=1, + remove_invalid_values=True, + ) + + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual(outputs, ["Wie alter sind Sie?"]) + @slow def test_constrained_beam_search_example_integration(self): tokenizer = AutoTokenizer.from_pretrained("t5-base") @@ -2389,3 +2493,43 @@ class GenerationIntegrationTests(unittest.TestCase): outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertListEqual(outputs, ["Wie alter sind Sie?"]) + + def test_constrained_beam_search_mixin_type_checks(self): + tokenizer = AutoTokenizer.from_pretrained("t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + encoder_input_str = "translate English to German: How old are you?" + input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + with self.assertRaises(ValueError): + force_words = ["sind"] + force_words_ids = tokenizer(force_words, return_tensors="pt").input_ids + model.generate( + input_ids, + force_words_ids=force_words_ids, + num_beams=10, + num_return_sequences=1, + no_repeat_ngram_size=1, + remove_invalid_values=True, + ) + + with self.assertRaises(ValueError): + force_words = ["sind"] + force_words_ids = [tokenizer(force_words, return_tensors="pt").input_ids] + model.generate( + input_ids, + force_words_ids=force_words_ids, + num_beams=10, + num_return_sequences=1, + no_repeat_ngram_size=1, + remove_invalid_values=True, + ) + + with self.assertRaises(ValueError): + model.generate(input_ids, force_words_ids=[]) + + with self.assertRaises(ValueError): + model.generate(input_ids, force_words_ids=[[-1]]) + + with self.assertRaises(ValueError): + model.generate(input_ids, force_words_ids=[[[-1]]])