From 89136ff7f8037fe064b5525e28a54f70f6f770e6 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 20 Jul 2023 12:23:17 +0100 Subject: [PATCH] Generate: sequence bias can handle same terminations (#24822) --- src/transformers/generation/logits_process.py | 36 +++++-------------- tests/generation/test_logits_process.py | 3 ++ 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 5ff46910d8..8a3c8ec7aa 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -624,9 +624,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size # is infered in the first usage, which inhibits initializing here) - self.sequences_length_greater_than_1 = [] self.length_1_bias = None - self.length_greather_than_1_bias = None self.prepared_bias_variables = False @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @@ -642,11 +640,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): bias += self.length_1_bias # 4 - include the bias from length > 1, after determining which biased sequences may be completed. - # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding - # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence - # may become complete this iteration. - matching_mask = torch.zeros_like(scores, dtype=torch.bool) - for sequence_ids in self.sequences_length_greater_than_1: + for sequence_ids, sequence_bias in self.sequence_bias.items(): + if len(sequence_ids) == 1: # the sequence is of length 1, already applied + continue if len(sequence_ids) > input_ids.shape[1]: # the sequence is longer than the context, ignore continue prefix_length = len(sequence_ids) - 1 @@ -655,12 +651,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): input_ids[:, -prefix_length:], torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device), ).prod(dim=1) - matching_mask[:, last_token] |= matching_rows.bool() - bias += torch.where( - matching_mask, - self.length_greather_than_1_bias, - torch.tensor(0.0, device=self.length_greather_than_1_bias.device), - ) + bias[:, last_token] += torch.where( + matching_rows.bool(), sequence_bias, torch.tensor(0.0, device=input_ids.device) + ) # 5 - apply the bias to the scores scores = scores + bias @@ -668,12 +661,10 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): def _prepare_bias_variables(self, scores: torch.FloatTensor): vocabulary_size = scores.shape[-1] - sequence_bias = self.sequence_bias - tokens_with_bias = [] # Check biased tokens out of bounds invalid_biases = [] - for sequence_ids in sequence_bias: + for sequence_ids in self.sequence_bias: for token_id in sequence_ids: if token_id >= vocabulary_size: invalid_biases.append(token_id) @@ -686,20 +677,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied # with simpler logic. self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device) - self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device) - for sequence_ids, bias in sequence_bias.items(): + for sequence_ids, bias in self.sequence_bias.items(): if len(sequence_ids) == 1: self.length_1_bias[sequence_ids[-1]] = bias - else: - self.sequences_length_greater_than_1.append(sequence_ids) - if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0: - raise ValueError( - "Setting a bias on sequences that share a common token termination is not yet supported. " - "Please open an issue if you see this error message (after checking that it doesn't already " - "exist)." - ) - self.length_greather_than_1_bias[sequence_ids[-1]] = bias - tokens_with_bias.append(sequence_ids[-1]) self.prepared_bias_variables = True diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index e560692d4c..fed27097a0 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -520,6 +520,9 @@ class LogitsProcessorTest(unittest.TestCase): input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long) positive_bias = {(1,): 100.0, (4,): 100.0} negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0} + # biases the same termination twice, to ensure we can handle overlapping terminations (it won't have an effect + # on the test cases, though) + negative_bias.update({(1, 3, 1, 3, 1, 3): -100.0}) sequence_bias = {**positive_bias, **negative_bias} # scores = 0 to facilitate checks