From 2b5603f6ac58f0cd3b2116c01d6b9f62575248b2 Mon Sep 17 00:00:00 2001 From: Chan Woo Kim Date: Thu, 10 Feb 2022 00:59:26 +0900 Subject: [PATCH] Constrained Beam Search [without disjunctive decoding] (#15416) * added classes to get started with constrained beam search * in progress, think i can directly force tokens now but not yet with the round robin * think now i have total control, now need to code the bank selection * technically works as desired, need to optimize and fix design choices leading to undersirable outputs * complete PR #1 without disjunctive decoding * removed incorrect tests * Delete k.txt * Delete test.py * Delete test.sh * revert changes to test scripts * genutils * full implementation with testing, no disjunctive yet * shifted docs * passing all tests realistically ran locally * removing accidentally included print statements * fixed source of error in initial PR test * fixing the get_device() vs device trap * fixed documentation docstrings about constrained_beam_search * fixed tests having failing for Speech2TextModel's floating point inputs * fix cuda long tensor * added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search * deleted accidentally added test halting code with assert False * code reformat * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen * Update tests/test_generation_utils.py * fixing based on comments on PR * took out the testing code that should but work fails without the beam search moditification ; style changes * fixing comments issues * docstrings for ConstraintListState * typo in PhrsalConstraint docstring * docstrings improvements Co-authored-by: Patrick von Platen --- docs/source/internal/generation_utils.mdx | 19 +- 
src/transformers/__init__.py | 10 +- .../generation_beam_constraints.py | 367 ++++++++++++++ src/transformers/generation_beam_search.py | 456 +++++++++++++++++- src/transformers/generation_utils.py | 378 ++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 28 ++ tests/test_generation_beam_search.py | 285 ++++++++++- tests/test_generation_utils.py | 344 ++++++++++++- 8 files changed, 1871 insertions(+), 16 deletions(-) create mode 100644 src/transformers/generation_beam_constraints.py diff --git a/docs/source/internal/generation_utils.mdx b/docs/source/internal/generation_utils.mdx index 88e5e9e315..9eb4abe06d 100644 --- a/docs/source/internal/generation_utils.mdx +++ b/docs/source/internal/generation_utils.mdx @@ -16,8 +16,9 @@ This page lists all the utility functions used by [`~generation_utils.Generation [`~generation_utils.GenerationMixin.greedy_search`], [`~generation_utils.GenerationMixin.sample`], [`~generation_utils.GenerationMixin.beam_search`], -[`~generation_utils.GenerationMixin.beam_sample`], and -[`~generation_utils.GenerationMixin.group_beam_search`]. +[`~generation_utils.GenerationMixin.beam_sample`], +[`~generation_utils.GenerationMixin.group_beam_search`], and +[`~generation_utils.GenerationMixin.constrained_beam_search`]. Most of those are only useful if you are studying the code of the generate methods in the library. @@ -190,6 +191,16 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than [[autodoc]] MaxTimeCriteria - __call__ +## Constraints + +A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. 
+ +[[autodoc]] Constraint + +[[autodoc]] PhrasalConstraint + +[[autodoc]] ConstraintListState + ## BeamSearch [[autodoc]] BeamScorer @@ -200,6 +211,10 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than - process - finalize +[[autodoc]] ConstrainedBeamSearchScorer + - process + - finalize + ## Utilities [[autodoc]] top_k_top_p_filtering diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2476d41fc8..f4b0e2908b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -612,7 +612,12 @@ if is_torch_available(): "TextDatasetForNextSentencePrediction", ] _import_structure["deepspeed"] = [] - _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"] + _import_structure["generation_beam_constraints"] = [ + "Constraint", + "ConstraintListState", + "PhrasalConstraint", + ] + _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"] _import_structure["generation_logits_process"] = [ "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", @@ -2750,7 +2755,8 @@ if TYPE_CHECKING: TextDataset, TextDatasetForNextSentencePrediction, ) - from .generation_beam_search import BeamScorer, BeamSearchScorer + from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint + from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_logits_process import ( ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, diff --git a/src/transformers/generation_beam_constraints.py b/src/transformers/generation_beam_constraints.py new file mode 100644 index 0000000000..6410d06928 --- /dev/null +++ b/src/transformers/generation_beam_constraints.py @@ -0,0 +1,367 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Union + +import torch + + +class Constraint(ABC): + r"""Abstract base class for all constraints that can be applied 
+ during generation. + It must define how the constraint can be satisfied. + + All classes that inherit Constraint must follow the requirement that + + ```py + completed = False + while not completed: + _, completed, _ = constraint.update(constraint.advance()) + ``` + + will always terminate (halt). + """ + + def __init__(self): + # test for the above condition + self.test() + + def test(self): + """ + Tests whether this constraint has been properly defined. + """ + counter = 0 + completed = False + while not completed: + if counter == 1: + self.reset() + advance = self.advance() + if not self.does_advance(advance): + raise Exception( + "Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true." + ) + + stepped, completed, reset = self.update(advance) + counter += 1 + + if counter > 10000: + raise Exception("update() does not fulfill the constraint.") + + if self.remaining() != 0: + raise Exception("Custom Constraint is not defined correctly.") + + @abstractmethod + def advance(self): + """ + When called, returns the token that would take this constraint one step closer to being fulfilled. + + Return: + token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + @abstractmethod + def does_advance(self, token_id: int): + """ + Reads in a token and returns whether it creates progress. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + @abstractmethod + def update(self, token_id: int): + """ + Reads in a token and returns booleans that indicate the progress made by it. This function will update the + state of this object unlike `does_advance(self, token_id: int)`. 
+ + This isn't to test whether a certain token will advance the progress; it's to update its state as if it has + been generated. This becomes important if token_id != desired token (refer to else statement in + PhrasalConstraint) + + Args: + token_id(`int`): + The id of a newly generated token in the beam search. + Return: + stepped(`bool`): + Whether this constraint has become one step closer to being fulfilled. + completed(`bool`): + Whether this constraint has been completely fulfilled by this token being generated. + reset (`bool`): + Whether this constraint has reset its progress by this token being generated. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + @abstractmethod + def reset(self): + """ + Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of + a constraint is interrupted by an unwanted token. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + @abstractmethod + def remaining(self): + """ + Returns the number of remaining steps of `advance()` in order to complete this constraint. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + @abstractmethod + def copy(self, stateful=False): + """ + Creates a new instance of this constraint. + + Args: + stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state. + + Return: + constraint(`Constraint`): The same constraint as the one being called from. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class PhrasalConstraint(Constraint): + r""" + [`Constraint`] enforcing that an ordered sequence of tokens is included in the output. 
+ + Args: + token_ids (`List[int]`): + The id of the token that must be generated by the output. + """ + + def __init__(self, token_ids: Union[List[int], torch.LongTensor]): + super(Constraint, self).__init__() + + is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int) + is_tensor = isinstance(token_ids, torch.Tensor) + is_int_tensor = ( + is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1 + ) + not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0 + if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive: + raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}") + + if not is_tensor: + token_ids = torch.tensor(token_ids) + + self.token_ids = token_ids + + self.seqlen = self.token_ids.size(0) + self.fulfilled_idx = -1 # the index of the currently fulfilled step + self.completed = False + + def advance(self): + return self.token_ids[self.fulfilled_idx + 1] + + def does_advance(self, token_id: int): + if self.completed: + return False + # move to cpu to guarantee no device issues. + return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu() + + def update(self, token_id: int): + stepped = False + completed = False + reset = False + + if self.does_advance(token_id): + self.fulfilled_idx += 1 + stepped = True + if self.fulfilled_idx == (self.seqlen - 1): + completed = True + self.completed = completed + else: + # failed to make progress. 
+ reset = True + self.reset() + return stepped, completed, reset + + def reset(self): + self.completed = False + self.fulfilled_idx = 0 + + def remaining(self): + return self.seqlen - (self.fulfilled_idx + 1) + + def copy(self, stateful=False): + new_constraint = PhrasalConstraint(self.token_ids) + + if stateful: + new_constraint.seq_len = self.seqlen + new_constraint.fulfilled_idx = self.fulfilled_idx + new_constraint.completed = self.completed + + return new_constraint + + +class ConstraintListState: + r""" + A class for beam scorers to track its progress through a list of constraints. + + Args: + constraints (`List[Constraint]`): + A list of [`Constraint`] objects that must be fulfilled by the beam scorer. + """ + + def __init__(self, constraints: List[Constraint]): + self.constraints = constraints + + # max # of steps required to fulfill a given constraint + self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)]) + self.n_constraints = len(constraints) + self.completed = False + + self.init_state() + + def init_state(self): + self.complete_constraints = [] + self.inprogress_constraint = None + self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints] + + def get_bank(self): + add = 0 + if self.inprogress_constraint: + # extra points for having a constraint mid-fulfilled + add += self.max_seqlen - self.inprogress_constraint.remaining() + + return (len(self.complete_constraints) * self.max_seqlen) + add + + def advance(self): + """The list of tokens to generate such that we can make progress. + By "list" we don't mean the list of token that will fully fulfill a constraint. 
+ + Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a + specific constraint `c_i`, we return: + + `[t_k1 for k in indices of unfulfilled constraints]` + + If we are in the middle of a constraint, then we return: + `[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint. + + Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint, + that's the only one we'll return. + """ + if self.inprogress_constraint is None: + token_list = [] + for constraint in self.pending_constraints: # "pending" == "unfulfilled yet" + advance = constraint.advance() + token_list.append(advance) + else: + token_list = [self.inprogress_constraint.advance()] + + if len(token_list) == 0: + return None + else: + return torch.stack(token_list) + + def reset(self, token_ids: Optional[torch.LongTensor]): + """ + token_ids: the tokens generated thus far to reset the state of the progress through constraints. + """ + self.init_state() + + if token_ids is not None and token_ids.size(0) > 0: + for token in token_ids: + # completes or steps **one** constraint + complete, stepped = self.add(token) + + # the entire list of constraints are fulfilled + if self.completed: + break + + return self + + def add(self, token_id: Union[int, torch.LongTensor]): + complete, stepped = False, False + + if self.completed: + complete = True + stepped = False + return complete, stepped + + if self.inprogress_constraint is not None: + # In the middle of fulfilling a constraint. If the `token_id` *does* makes an incremental progress to current + # job, simply update the state + + stepped, complete, reset = self.inprogress_constraint.update(token_id) + if reset: + # 1. If the next token breaks the progress, then we must restart. + # e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books". 
+ + # But that doesn't mean we self.init_state(), since we only reset the state for this particular + # constraint, not the full list of constraints. + + self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False)) + self.inprogress_constraint = None + + if complete: + # 2. If the next token completes the constraint, move it to completed list, set + # inprogress to None. If there are no pending constraints either, then this full list of constraints + # is complete. + + self.complete_constraints.append(self.inprogress_constraint) + self.inprogress_constraint = None + + if len(self.pending_constraints) == 0: + # we're done! + self.completed = True + + else: + # Not in the middle of fulfilling a constraint. So does this `token_id` helps us step towards any of our list + # of constraints? + + for cidx, pending_constraint in enumerate(self.pending_constraints): + if pending_constraint.does_advance(token_id): + stepped, complete, reset = pending_constraint.update(token_id) + + if not stepped: + raise Exception( + "constraint.update(token_id) is not yielding incremental progress, " + "even though constraint.does_advance(token_id) is true." + ) + + if complete: + self.complete_constraints.append(pending_constraint) + self.inprogress_constraint = None + + if not complete and stepped: + self.inprogress_constraint = pending_constraint + + if complete or stepped: + # If we made any progress at all, then it's at least not a "pending constraint". + + self.pending_constraints = ( + self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :] + ) + + if len(self.pending_constraints) == 0 and self.inprogress_constraint is None: + # If there's no longer any pending after this and no inprogress either, then we must be + # complete. + + self.completed = True + + break # prevent accidentally stepping through multiple constraints with just one token. 
+ + return complete, stepped + + def copy(self, stateful=True): + new_state = ConstraintListState(self.constraints) # we actually never though self.constraints objects + # throughout this process. So it's at initialization state. + + if stateful: + new_state.complete_constraints = [ + constraint.copy(stateful=True) for constraint in self.complete_constraints + ] + if self.inprogress_constraint is not None: + new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True) + new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints] + + return new_state diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index bf84df2c09..81dc0c5a55 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -16,11 +16,13 @@ import warnings from abc import ABC, abstractmethod from collections import UserDict -from typing import Optional, Tuple +from typing import List, Optional, Tuple +import numpy as np import torch from .file_utils import add_start_docstrings +from .generation_beam_constraints import Constraint, ConstraintListState PROCESS_INPUTS_DOCSTRING = r""" @@ -336,12 +338,462 @@ class BeamSearchScorer(BeamScorer): if sent_lengths.min().item() != sent_lengths.max().item(): assert pad_token_id is not None, "`pad_token_id` has to be defined" decoded.fill_(pad_token_id) - # fill with hypotheses and eos_token_id if the latter fits in for i, hypo in enumerate(best): decoded[i, : sent_lengths[i]] = hypo if sent_lengths[i] < max_length: decoded[i, sent_lengths[i]] = eos_token_id + + return UserDict( + { + "sequences": decoded, + "sequence_scores": best_scores, + } + ) + + +class ConstrainedBeamSearchScorer(BeamScorer): + r""" + [`BeamScorer`] implementing constrained beam search decoding. + + + Args: + batch_size (`int`): + Batch Size of `input_ids` for which standard beam search decoding is run in parallel. 
+ max_length (`int`): + The maximum length of the sequence to be generated. Deprecated: passing it to this scorer has no effect. + num_beams (`int`): + Number of beams for beam search. + constraints (`List[Constraint]`): + A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation + output. For more information, the documentation of [`Constraint`] should be read. + device (`torch.device`): + Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `ConstrainedBeamSearchScorer` will be + allocated. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + do_early_stopping (`bool`, *optional*, defaults to `False`): + Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. + num_beam_hyps_to_keep (`int`, *optional*, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + [`~transformers.ConstrainedBeamSearchScorer.finalize`]. + num_beam_groups (`int`, *optional*, defaults to 1): + Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. 
+ """ + + def __init__( + self, + batch_size: int, + num_beams: int, + constraints: List[Constraint], + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[bool] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + num_beam_groups: Optional[int] = 1, + **kwargs, + ): + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + self.num_beam_groups = num_beam_groups + self.group_size = self.num_beams // self.num_beam_groups + self.constraints = constraints + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + ) + for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." + ) + + if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): + raise ValueError( + f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` " + f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." + ) + + if "max_length" in kwargs: + warnings.warn( + "Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. " + "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" + ", or `group_beam_search(...)`." 
+ ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def make_constraint_states(self, n): + return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)] + + def check_completes_constraints(self, sequence): + new_state = self.make_constraint_states(1)[0] + new_state = new_state.reset(sequence) + return new_state.completed + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + scores_for_all_vocab: torch.FloatTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> Tuple[torch.Tensor]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`): + Current scores of the top `2 * num_beams` non-finished beam hypotheses. + next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`): + `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses. + next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`): + Beam indices indicating to which beam hypothesis the `next_tokens` correspond. + scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`): + The scores of all tokens in the vocabulary for each of the beam hypotheses. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. 
+ + Return: + `UserDict`: A dictionary composed of the fields as defined above: + + - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of + all + non-finished beams. + - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be + added + to the non-finished beam_hypotheses. + - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices + indicating to which beam the next tokens shall be added. + """ + + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps) + if not (batch_size == (input_ids.shape[0] // self.group_size)): + if self.num_beam_groups > 1: + raise ValueError( + f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam " + f"size of {self.group_size} is expected by the beam scorer." + ) + else: + raise ValueError( + f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of " + f"{self.group_size} is expected by the beam scorer." + ) + + device = input_ids.device + + next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + if self.num_beams < len(beam_hyp): + raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated") + if eos_token_id is None or pad_token_id is None: + raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined") + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence. 
+ beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.group_size + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() == eos_token_id): + + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size + if is_beam_token_worse_than_top_num_beams: + continue + + completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx]) + if completes_constraint: + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.group_size: + break + + new_scores, new_tokens, new_indices = self.step_sentence_constraint( + batch_idx, + input_ids, + scores_for_all_vocab, + next_beam_scores[batch_idx], + next_beam_tokens[batch_idx], + next_beam_indices[batch_idx], + ) + + next_beam_scores[batch_idx] = new_scores + next_beam_tokens[batch_idx] = new_tokens + next_beam_indices[batch_idx] = new_indices + + if beam_idx < self.group_size: + raise ValueError( + f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." 
+ ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def step_sentence_constraint( + self, + batch_idx: int, + input_ids: torch.LongTensor, + vocab_scores: torch.FloatTensor, + sent_beam_scores: torch.FloatTensor, + sent_beam_tokens: torch.LongTensor, + sent_beam_indices: torch.LongTensor, + push_progress: bool = False, + ): + # sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam + # (candidate next tokens) + + # 1. Adding "advance_tokens" + # using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will + # advance us in fulfilling the constraints. + + # 2. Selecting best candidates such that we end up with highest probable candidates + # that fulfill our constraints. + + orig_len = sent_beam_indices.size(0) + device = sent_beam_indices.device + + # initialize states + topk_contraint_states = self.make_constraint_states(orig_len) + advance_constraint_states = self.make_constraint_states(orig_len) + + sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len + this_batch_input_ids = input_ids[sidx:eidx] + this_batch_token_scores = vocab_scores[sidx:eidx] + full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1) + + # need to make new hypothesis that advance the constraints + track_new = {"new_seqs": [], "new_states": [], "new_indices": [], "new_tokens": [], "new_scores": []} + for seq_idx, pre_seq in enumerate(this_batch_input_ids): + # pre_seq = ith sequence generated before this step. 
+ + # input_ids -> (topk) generic beam search best model next tokens + # -> (advance) constraints forcing the next token + # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of + # hypotheses. + + topk_state = topk_contraint_states[seq_idx] + topk_state.reset(full_hypotheses[seq_idx]) + + advance_state = advance_constraint_states[seq_idx] + advance_state.reset(pre_seq) + + if not advance_state.completed: + advance_tokens = advance_state.advance() + for advance_token in advance_tokens.to(device): + # since adding each `advance_token` leads to a different hypothesis, create new state instance. + new_state = advance_state.copy(stateful=True) + new_state.add(advance_token) + + advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist() + if advance_seq not in track_new["new_seqs"]: + # prevent duplicates, which are basically bound to happen in this process. + track_new["new_seqs"].append(advance_seq) + track_new["new_indices"].append(seq_idx) + track_new["new_tokens"].append(advance_token) + track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token)) + track_new["new_states"].append(new_state) + elif push_progress: + # Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that + # actually fulfill our constraints. For example, let constraints == ["loves pies"] and + + # pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and" + + # Without this step, if `sent_beam_indices` is something like [1,1], then + # 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and + # 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is + # the else part of `if constraints_completed[seq_idx]`) + # 3. it ends up simply getting removed from consideration. 
+ + # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways, + # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam + # search times, since completed sequences keep getting removed after all this effort for constrained + # generation. + + # Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply + # appending the next likely token in the vocabulary and adding it to the list of hypotheses. + + new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token + advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1) + + advance_state = advance_constraint_states[seq_idx] + + advance_state.reset(advance_seq) + advance_seq = advance_seq.cpu().tolist() + if advance_seq not in track_new["new_seqs"]: + # but still don't want to have duplicates + track_new["new_seqs"].append(advance_seq) + track_new["new_indices"].append(seq_idx) + track_new["new_tokens"].append(new_token) + track_new["new_scores"].append(new_score) + track_new["new_states"].append(advance_state) + + if len(track_new["new_indices"]) > 0: + new_indices = torch.tensor(track_new["new_indices"]).to(device) + new_tokens = torch.stack(track_new["new_tokens"]).to(device) + new_scores = torch.stack(track_new["new_scores"]).to(device) + + all_states = topk_contraint_states + track_new["new_states"] + all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1) + all_scores = torch.cat((sent_beam_scores, new_scores), -1) + all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device) + + zipped = all_banks * 100 + all_scores + indices = zipped.sort(descending=True).indices + sorted_banks = all_banks[indices] + + # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0} + + counter = -1 + cur_bank = sorted_banks[0] + increments = [] + for bank in sorted_banks: + if bank == cur_bank: + counter 
+= 1 + else: + counter = 0 + cur_bank = bank + increments.append(counter) + rearrangers = torch.tensor(np.argsort(increments, kind="mergesort")) + + indices = indices[rearrangers][:orig_len] + + sent_beam_scores = all_scores[indices] + sent_beam_tokens = all_tokens[indices] + sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices] + + return sent_beam_scores, sent_beam_tokens, sent_beam_indices + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + max_length: int, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> Tuple[torch.LongTensor]: + batch_size = len(self._beam_hyps) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # all open beam hypotheses are added to the beam hypothesis + # beam hypothesis class automatically keeps the best beams + + ids_collect = [] + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + + completes_constraint = self.check_completes_constraints(final_tokens) + if completes_constraint: + beam_hyp.add(final_tokens, final_score) + ids_collect.append(beam_id) + + # due to overly complex constraints or other factors, sometimes we can't gaurantee a successful + # generation. In these cases we simply return the highest scoring outputs. 
+ if len(ids_collect) < self.num_beam_hyps_to_keep: + for beam_id in range(self.num_beams): + if beam_id not in ids_collect: + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_hyp.add(final_tokens, final_score) + if len(ids_collect) >= self.num_beam_hyps_to_keep: + break + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32) + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp_tuple = sorted_hyps.pop() + best_score = best_hyp_tuple[0] + best_hyp = best_hyp_tuple[1] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + + # append to lists + best.append(best_hyp) + best_scores[i * self.num_beam_hyps_to_keep + j] = best_score + + # prepare for adding eos + sent_lengths_max = sent_lengths.max().item() + 1 + sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < sent_max_len: + decoded[i, sent_lengths[i]] = eos_token_id return UserDict( { "sequences": decoded, diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 895ac03114..9e19c51765 100644 --- a/src/transformers/generation_utils.py +++ 
b/src/transformers/generation_utils.py @@ -24,7 +24,8 @@ import torch.distributed as dist from torch import nn from .file_utils import ModelOutput -from .generation_beam_search import BeamScorer, BeamSearchScorer +from .generation_beam_constraints import Constraint +from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, ForcedBOSTokenLogitsProcessor, @@ -839,6 +840,7 @@ class GenerationMixin: prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), + constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -860,7 +862,6 @@ class GenerationMixin: post](https://huggingface.co/blog/how-to-generate). Parameters: - inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length, feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the @@ -945,6 +946,9 @@ class GenerationMixin: Custom stopping criteria that complement the default stopping criteria built from arguments and a model's config. If a stopping criteria is passed that is already created with the arguments or a model's config an error is thrown. This feature is intended for advanced users. + constraints (`List[Constraint]`, *optional*): + Custom constraints that can be added to the generation to ensure that the output will contain the use + of certain tokens as defined by `Constraint` objects, in the most sensible way possible. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned tensors for more details. @@ -966,7 +970,6 @@ class GenerationMixin: crash. Note that using `remove_invalid_values` can slow down generation. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs @@ -1140,11 +1143,14 @@ class GenerationMixin: ) # 6. determine generation mode - is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False - is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True - is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False - is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True - is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) + is_constraint_gen_mode = constraints is not None + is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None + is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None + is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None + is_beam_sample_gen_mode = ( + (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None + ) + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1339,6 +1345,50 @@ class GenerationMixin: **model_kwargs, ) + elif is_constraint_gen_mode: + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if 
stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + if num_beams <= 1: + raise ValueError("`num_beams` needs to be greater than 1 for constrained genertation.") + + if do_sample: + raise ValueError("`do_sample` needs to be false for constrained generation.") + + if num_beam_groups is not None and num_beam_groups > 1: + raise ValueError("`num_beam_groups` not supported yet for constrained generation.") + + # 10. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=constraints, + batch_size=batch_size, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + # 12. 
def constrained_beam_search(
    self,
    input_ids: torch.LongTensor,
    constrained_beam_scorer: ConstrainedBeamSearchScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = None,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    r"""
    Generates sequences for models with a language modeling head using constrained beam search decoding.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation, while satisfying a list of positive constraints. For more information, the
            documentation of [`ConstrainedBeamSearchScorer`] should be read.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        max_length (`int`, *optional*):
            **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
            tokens. The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`int`, *optional*):
            The id of the *end-of-sequence* token.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: If the batch dimension of `input_ids` is not `batch_size * num_beams` as configured on
            `constrained_beam_scorer`.

    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     AutoModelForSeq2SeqLM,
    ...     LogitsProcessorList,
    ...     MinLengthLogitsProcessor,
    ...     ConstrainedBeamSearchScorer,
    ...     PhrasalConstraint,
    ... )
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    >>> encoder_input_str = "translate English to German: How old are you?"
    >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

    >>> # lets run beam search using 3 beams
    >>> num_beams = 3
    >>> # define decoder start token ids
    >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
    >>> input_ids = input_ids * model.config.decoder_start_token_id

    >>> # add encoder_outputs to model keyword arguments
    >>> model_kwargs = {
    ...     "encoder_outputs": model.get_encoder()(
    ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    ...     )
    ... }

    >>> constraint_str = "sind"
    >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # slice to remove eos token
    >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]

    >>> # instantiate beam scorer; the constraints are carried by the scorer, not by the search call
    >>> beam_scorer = ConstrainedBeamSearchScorer(
    ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
    ... )

    >>> # instantiate logits processors
    >>> logits_processor = LogitsProcessorList(
    ...     [
    ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
    ...     ]
    ... )

    >>> outputs = model.constrained_beam_search(
    ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
    ... )

    >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
    # => ['Wie alter sind Sie?']
    ```"""
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    batch_size = len(constrained_beam_scorer._beam_hyps)
    num_beams = constrained_beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # Every beam of a batch entry starts from the same prompt, so only the first beam gets a usable
    # score; the others start at -1e9 so the first top-k does not pick the same continuation repeatedly.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:

        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]
        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)

        # per-beam scores over the whole vocabulary (without the running beam score added); passed to
        # the scorer in addition to the top-k candidates
        scores_for_all_vocab = next_token_scores_processed.clone()

        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # Split the flat (beam * vocab_size + token) index. Integer floor division is used on purpose:
        # true division would round-trip through float32 and can yield a wrong beam index once
        # `num_beams * vocab_size` gets close to 2**24.
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = constrained_beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            scores_for_all_vocab,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        # increase cur_len
        cur_len = cur_len + 1

        if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    sequence_outputs = constrained_beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
class ConstrainedBeamSearchTester:
    """Drives a `ConstrainedBeamSearchScorer` through the hypothesis, `process` and `finalize` checks.

    `parent` is the test case whose `assert*` methods are used to report results; the remaining
    constructor arguments are stored as defaults for the scorers and random inputs built below.
    """

    def __init__(
        self,
        parent,
        constraints=None,
        batch_size=3,
        sequence_length=10,
        vocab_size=99,
        pad_token_id=0,
        max_length=20,
        num_beams=4,
        length_penalty=2.0,
        do_early_stopping=True,
        num_beam_hyps_to_keep=2,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep

        if constraints is None:
            # default: one phrasal constraint made of two random token ids drawn from [10, 50)
            forced = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
            constraints = [PhrasalConstraint(forced)]
        self.constraints = constraints
        # chosen above `vocab_size` so that it cannot be randomly generated
        self.eos_token_id = vocab_size + 1

    def prepare_constrained_beam_scorer(self, **kwargs):
        """Build a scorer from the tester defaults; keyword arguments override individual settings."""

        def setting(name):
            return kwargs.get(name, getattr(self, name))

        return ConstrainedBeamSearchScorer(
            constraints=setting("constraints"),
            batch_size=setting("batch_size"),
            num_beams=setting("num_beams"),
            device=torch_device,
            length_penalty=setting("length_penalty"),
            do_early_stopping=setting("do_early_stopping"),
            num_beam_hyps_to_keep=setting("num_beam_hyps_to_keep"),
        )

    def prepare_inputs(self):
        """Return random `(input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)`."""
        beam_rows = self.batch_size * self.num_beams
        candidate_shape = (self.batch_size, 2 * self.num_beams)

        input_ids = ids_tensor((beam_rows, self.sequence_length), self.vocab_size)
        next_tokens = ids_tensor(candidate_shape, self.vocab_size).to(torch_device)
        next_indices = ids_tensor(candidate_shape, self.num_beams).to(torch_device)

        # scores are negated random floats, sorted so that the best candidate comes first
        negated_scores = -floats_tensor(candidate_shape).to(torch_device)
        next_scores = negated_scores.sort(descending=True).values

        negated_vocab_scores = -floats_tensor((beam_rows, self.vocab_size)).to(torch_device)
        scores_for_all_vocab = negated_vocab_scores.sort(descending=True).values

        return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)

    def check_beam_hypotheses(self, input_ids, *args):
        """Check the `BeamHypotheses` containers the scorer creates, with and without early stopping."""
        parent = self.parent

        # --- early stopping enabled ---
        scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
        hypotheses = scorer._beam_hyps[0]

        # one hypotheses container per batch entry
        parent.assertEqual(len(scorer._beam_hyps), self.batch_size)

        # check correct type
        parent.assertTrue(isinstance(hypotheses, BeamHypotheses))

        # check that num_beams is correctly set
        parent.assertEqual(hypotheses.num_beams, self.num_beams)

        # fill the container with `num_beams` equally scored hypotheses
        for row in range(self.num_beams):
            hypotheses.add(input_ids[row], -10.0)

        # if early stopping True -> score does not matter
        parent.assertTrue(hypotheses.is_done(-10.0, 5))

        # --- early stopping disabled ---
        scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
        hypotheses = scorer._beam_hyps[0]

        # add `num_beams + 1` beams to change `worst_score`
        for row in range(self.num_beams + 1):
            hypotheses.add(input_ids[row], -10.0 + float(row))

        # -10.0 is removed => -9.0 is worst score
        expected_worst = -9.0 / (self.sequence_length ** hypotheses.length_penalty)
        parent.assertAlmostEqual(hypotheses.worst_score, expected_worst)

        # -5.0 is better than worst score => should not be finished
        parent.assertFalse(hypotheses.is_done(-5.0, self.sequence_length))

        # -20.0 is worse than worst score => should be finished
        parent.assertTrue(hypotheses.is_done(-20.0, self.sequence_length))

    def check_constrained_beam_scorer_update(
        self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
    ):
        """Check `process`: too many eos candidates, all batches finishing, and a regular update step."""
        parent = self.parent
        eos = self.eos_token_id

        def run_process(scorer, candidate_tokens):
            return scorer.process(
                input_ids, next_scores, candidate_tokens, next_indices, scores_for_all_vocab, eos_token_id=eos
            )

        # --- case 1: too many eos tokens ---
        scorer = self.prepare_constrained_beam_scorer()

        # write the constraint tokens at the start of every row so that each row fulfils the constraints
        required_tokens = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
        input_ids[:, : required_tokens.shape[0]] = required_tokens

        tokens = next_tokens.clone()
        tokens[0, :] = eos

        with parent.assertRaises(ValueError):
            run_process(scorer, tokens)

        # --- case 2: check all batches are done ---
        scorer = self.prepare_constrained_beam_scorer()

        tokens = next_tokens.clone()
        tokens[:, : self.num_beams] = eos
        run_process(scorer, tokens)
        # beam scorer should be done
        parent.assertTrue(scorer.is_done)

        # --- case 3: a single eos candidate (column 1) per batch entry ---
        scorer = self.prepare_constrained_beam_scorer()

        tokens = next_tokens.clone()
        tokens[:, 1] = eos
        beam_outputs = run_process(scorer, tokens)
        output_scores = beam_outputs["next_beam_scores"]
        output_tokens = beam_outputs["next_beam_tokens"]
        output_indices = beam_outputs["next_beam_indices"]

        # cut out the eos column and take the best `num_beams` remaining candidates
        kept_columns = [0, *range(2, self.num_beams + 1)]

        def without_eos_column(tensor):
            return tensor[:, kept_columns].flatten()

        expected_output_tokens = without_eos_column(tokens)
        expected_output_scores = without_eos_column(next_scores)

        # add num_beams * batch_idx
        row_offsets = (
            torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams
        ) * self.num_beams
        expected_output_indices = without_eos_column(next_indices) + row_offsets

        parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
        parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
        parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))

        # make sure ids of eos token are correctly saved in beam_hyps of beam scorer
        for batch_idx in range(self.batch_size):
            finished_row = batch_idx * self.num_beams + next_indices[batch_idx, 1]
            stored_hypothesis = scorer._beam_hyps[batch_idx].beams[0][-1]
            parent.assertListEqual(input_ids[finished_row].tolist(), stored_hypothesis.tolist())

    def check_constrained_beam_scorer_finalize(
        self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
    ):
        """Check `finalize`: output shapes, eos placement and that every returned sequence has the constraint."""
        parent = self.parent
        eos = self.eos_token_id

        # max_length should be only one more than current input_ids to check that eos is correctly appended
        max_length = self.sequence_length + 1

        # for testing finalize, we do want to have fulfilled constraints
        required_tokens = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
        input_ids[:, : required_tokens.shape[0]] = required_tokens

        scorer = self.prepare_constrained_beam_scorer(
            num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
        )
        constraints = scorer.constraints

        # update beams and append to input_ids
        tokens = next_tokens.clone()
        # first batch, first output has to finish with eos token id since scores are correctly sorted
        tokens[0, 0] = eos
        # make sure corresponding score is as good as possible to surely be picked first
        next_scores[0, 0] = 0.0

        beam_outputs = scorer.process(
            input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=eos
        )
        output_scores = beam_outputs["next_beam_scores"]
        output_tokens = beam_outputs["next_beam_tokens"]
        output_indices = beam_outputs["next_beam_indices"]
        input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)

        def run_finalize(active_scorer):
            return active_scorer.finalize(
                input_ids,
                output_scores,
                output_tokens,
                output_indices,
                pad_token_id=self.pad_token_id,
                eos_token_id=eos,
                max_length=max_length,
            )

        # finalize
        sequence_output = run_finalize(scorer)
        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]

        # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
        parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
        parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])

        # check sequence_scores
        parent.assertFalse((sequence_scores > 0).any().item())

        # first batch has to finish with eos_token
        parent.assertEqual(sequences[0, -1].item(), eos)

        # other batches cannot finish with eos token
        parent.assertNotEqual(sequences[1, -1].item(), eos)
        parent.assertNotEqual(sequences[2, -1].item(), eos)

        # test that the constraint is indeed fulfilled
        for output in sequences:
            for constraint in constraints:
                forced_token_ids = constraint.token_ids
                parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)

        # now test that if `num_beam_hyps_to_keep` equals `num_beams` => all beams are returned
        scorer = self.prepare_constrained_beam_scorer(
            num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
        )

        sequence_output = run_finalize(scorer)
        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]

        parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
        parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])

    def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
        """Return `True` if the shorter of the two 1-D tensors occurs contiguously inside the longer one."""
        # set to same device. we don't care what device.
        first, second = tensor_1.cpu(), tensor_2.cpu()

        # on equal length `tensor_1` is treated as the one to look for
        if first.size(0) <= second.size(0):
            needle, haystack = first, second
        else:
            needle, haystack = second, first

        width = needle.size(0)
        # slide a window of the needle's length over the haystack and stop at the first match
        return any(
            torch.equal(haystack[start : start + width], needle)
            for start in range(haystack.size(0) - width + 1)
        )
AutoModelForSeq2SeqLM, + AutoTokenizer, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, @@ -37,7 +39,8 @@ if is_torch_available(): VisionEncoderDecoderModel, top_k_top_p_filtering, ) - from transformers.generation_beam_search import BeamSearchScorer + from transformers.generation_beam_constraints import PhrasalConstraint + from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation_logits_process import ( ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, @@ -190,6 +193,25 @@ class GenerationTesterMixin: ) return beam_kwargs, beam_scorer + @staticmethod + def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1): + beam_kwargs = { + "early_stopping": False, + "length_penalty": 2.0, + "num_beams": num_return_sequences * 4, + "num_return_sequences": num_return_sequences, + } + beam_scorer = ConstrainedBeamSearchScorer( + batch_size=batch_size, + constraints=constraints, + num_beams=beam_kwargs["num_beams"], + device=torch_device, + length_penalty=beam_kwargs["length_penalty"], + do_early_stopping=beam_kwargs["early_stopping"], + num_beam_hyps_to_keep=num_return_sequences, + ) + return beam_kwargs, beam_scorer + @staticmethod def _get_encoder_outputs( model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 @@ -526,6 +548,69 @@ class GenerationTesterMixin: ) return output_generate, output_group_beam_search + def _constrained_beam_search_generate( + self, + model, + input_ids, + attention_mask, + max_length, + constrained_beam_scorer, + constraints, + beam_kwargs, + logits_processor, + logits_process_kwargs, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + output_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + output_scores=output_scores, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, + constraints=constraints, + **beam_kwargs, + **logits_process_kwargs, + ) + + # constrained_beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + num_interleave=constrained_beam_scorer.num_beams, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_group_beam_search = model.constrained_beam_search( + input_ids_clone, + constrained_beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + return output_generate, output_group_beam_search + def test_greedy_generate(self): # check `generate()` and `greedy_search()` are equal for model_class in self.all_generative_model_classes: @@ -719,6 +804,7 @@ class GenerationTesterMixin: logits_process_kwargs=logits_process_kwargs, logits_processor=logits_processor, ) + self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) # check `generate()` and `beam_search()` are equal for `num_return_sequences` @@ -1085,6 +1171,164 @@ class GenerationTesterMixin: output, input_ids, model.config,
num_return_sequences=num_return_sequences * beam_scorer.num_beams ) + def test_constrained_beam_search_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # It is important to set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None + + model = model_class(config).to(torch_device).eval() + max_length = 20 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) + + # check `generate()` and `constrained_beam_search()` are equal + # Sample constraints + if not input_ids.dtype == torch.float32: + min_id = torch.min(input_ids) + 3 + max_id = torch.max(input_ids) + else: + # otherwise this throws an error for Speech2TextModel since its inputs are floating points + min_id = 3 + max_id = 100 + + force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] + constraints = [ + PhrasalConstraint(force_tokens), + ] + + beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, constraints, num_return_sequences=1 + ) + output_generate, output_beam_search = self._constrained_beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + constrained_beam_scorer=beam_scorer, + constraints=constraints, + beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, + ) + self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + for generation_output in output_generate: +
self._check_sequence_inside_sequence(force_tokens, generation_output) + + # check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences` + # Sample constraints + force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] + constraints = [ + PhrasalConstraint(force_tokens), + ] + + num_return_sequences = 2 + max_length = 20 + + beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences + ) + + output_generate, output_beam_search = self._constrained_beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + constrained_beam_scorer=beam_scorer, + constraints=constraints, + beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, + ) + self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + + for generation_output in output_generate: + self._check_sequence_inside_sequence(force_tokens, generation_output) + + def test_constrained_beam_search_generate_dict_output(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + # It is important to set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None + + model = model_class(config).to(torch_device).eval() + if model.config.is_encoder_decoder: + max_length = 20 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) + + #
Sample constraints + if not input_ids.dtype == torch.float32: + min_id = torch.min(input_ids) + 3 + max_id = torch.max(input_ids) + else: + # otherwise this throws an error for Speech2TextModel since its inputs are floating points + min_id = 3 + max_id = 100 + force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] + constraints = [ + PhrasalConstraint(force_tokens), + ] + + beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, constraints, num_return_sequences=1 + ) + output_generate, output_beam_search = self._constrained_beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + constrained_beam_scorer=beam_scorer, + constraints=constraints, + beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + if model.config.is_encoder_decoder: + self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist()) + self.assertTrue( + torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3) + ) + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + + for output in (output_beam_search, output_generate): + self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) + def test_generate_with_head_masking(self): """Test designed for encoder-decoder 
models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] @@ -1254,6 +1498,24 @@ class GenerationTesterMixin: [encoder_expected_shape] * len(hidden_states), ) + def _check_sequence_inside_sequence(self, tensor_1, tensor_2): + # set to same device. we don't care what device. + tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu() + + in_order = tensor_1.size(0) <= tensor_2.size(0) + longer = tensor_2 if in_order else tensor_1 + shorter = tensor_1 if in_order else tensor_2 + + flag = False + chunk_size = shorter.size(0) + for chunk_idx in range(longer.size(0) - chunk_size + 1): + subseq = longer[chunk_idx : chunk_idx + chunk_size] + if torch.equal(subseq, shorter): + flag = True + break + + self.assertTrue(flag) + @require_torch class UtilsFunctionsTest(unittest.TestCase): @@ -2047,3 +2309,83 @@ class GenerationIntegrationTests(unittest.TestCase): transition_scores_sum = transition_scores.sum(-1) self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) + + @slow + def test_constrained_beam_search(self): + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + + force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0] + force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0] + + constraints = [ + PhrasalConstraint(force_tokens), + PhrasalConstraint(force_tokens_2), + ] + + starting_text = ["The soldiers were not prepared and"] + + input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device) + + outputs = model.generate( + input_ids, + constraints=constraints, + num_beams=10, + num_return_sequences=1, + no_repeat_ngram_size=1, + max_length=30, + remove_invalid_values=True, + ) + + generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "The 
soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do", + ], + ) + + @slow + def test_constrained_beam_search_example_integration(self): + tokenizer = AutoTokenizer.from_pretrained("t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + encoder_input_str = "translate English to German: How old are you?" + encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + # lets run beam search using 5 beams + num_beams = 5 + # define decoder start token ids + input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + input_ids = input_ids * model.config.decoder_start_token_id + + # add encoder_outputs to model keyword arguments + model_kwargs = { + "encoder_outputs": model.get_encoder()( + encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ) + } + + constraint_str = "sind" + constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token + constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] + + # instantiate beam scorer + beam_scorer = ConstrainedBeamSearchScorer( + batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints + ) + + # instantiate logits processors + logits_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ] + ) + + outputs = model.constrained_beam_search( + input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs + ) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual(outputs, ["Wie alter sind Sie?"])