From 8f2f0f0f85f9e517c495b2083c218215819bae34 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 26 Feb 2024 21:06:16 +0500 Subject: [PATCH] Track each row separately for stopping criteria (#29116) --- .../generation/stopping_criteria.py | 26 +++++++----- src/transformers/generation/utils.py | 40 ++++++++----------- tests/generation/test_stopping_criteria.py | 22 +++++----- 3 files changed, 43 insertions(+), 45 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index ca3e850964..8516c61572 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" Additional stopping criteria specific kwargs. Return: - `bool`. `False` indicates we should continue, `True` indicates we should stop. + `torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation + for a particular row, `True` indicates we should continue. """ @@ -42,7 +43,7 @@ class StoppingCriteria(ABC): """ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: raise NotImplementedError("StoppingCriteria needs to be subclassed") @@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria): self.max_position_embeddings = max_position_embeddings @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: cur_len = input_ids.shape[-1] is_done = cur_len >= self.max_length if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings: @@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria): f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe " "exceptions, performance degradation, or nothing at all." ) - return is_done + return torch.full((input_ids.shape[0],), is_done, device=input_ids.device) class MaxNewTokensCriteria(StoppingCriteria): @@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria): self.max_length = start_length + max_new_tokens @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - return input_ids.shape[-1] >= self.max_length + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + is_done = input_ids.shape[-1] >= self.max_length + return torch.full((input_ids.shape[0],), is_done, device=input_ids.device) class MaxTimeCriteria(StoppingCriteria): @@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria): self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - return time.time() - self.initial_timestamp > self.max_time + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + is_done = time.time() - self.initial_timestamp > self.max_time + return torch.full((input_ids.shape[0],), is_done, device=input_ids.device) class StoppingCriteriaList(list): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - return any(criteria(input_ids, scores, **kwargs) for criteria in self) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device) + for criteria in self: + is_done = is_done | criteria(input_ids, scores, **kwargs) + return is_done @property def max_length(self) -> Optional[int]: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c7e03123a9..ff5421ad48 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2194,12 +2194,10 @@ class GenerationMixin: next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True + # stop when each sentence is finished + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores): + if unfinished_sequences.max() == 0: this_peer_finished = True if this_peer_finished and not synced_gpus: @@ -2478,12 +2476,10 @@ class GenerationMixin: next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores): + # stop when each sentence is finished + if unfinished_sequences.max() == 0: this_peer_finished = True if this_peer_finished and not synced_gpus: @@ -2772,12 +2768,10 @@ class GenerationMixin: next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores): + # stop when each sentence is finished + if unfinished_sequences.max() == 0: this_peer_finished = True if this_peer_finished and not synced_gpus: @@ -3169,7 +3163,7 @@ class GenerationMixin: # increase cur_len cur_len = cur_len + 1 - if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): if not synced_gpus: break else: @@ -3516,7 +3510,7 @@ class GenerationMixin: # increase cur_len cur_len = cur_len + 1 - if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): if not synced_gpus: break else: @@ -3912,7 +3906,7 @@ class GenerationMixin: # increase cur_len cur_len = cur_len + 1 - if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): if not synced_gpus: break else: @@ -4267,7 +4261,7 @@ class GenerationMixin: # increase cur_len cur_len = cur_len + 1 - if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): + if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): if not synced_gpus: break else: @@ -4657,12 +4651,10 @@ class GenerationMixin: .prod(dim=0) ) - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores): + # stop when each sentence is finished + if unfinished_sequences.max() == 0: this_peer_finished = True if this_peer_finished and not synced_gpus: diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index dfc5308359..7fa118c9e3 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase): ] ) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(9) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(10) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores))) def test_max_length_criteria(self): criteria = MaxLengthCriteria(max_length=10) input_ids, scores = self._get_tensors(5) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(9) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(10) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores))) def test_max_new_tokens_criteria(self): criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5) input_ids, scores = self._get_tensors(5) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(9) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(10) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores))) criteria_list = StoppingCriteriaList([criteria]) self.assertEqual(criteria_list.max_length, 10) @@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase): input_ids, scores = self._get_tensors(5) criteria = MaxTimeCriteria(max_time=0.1) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores))) def test_validate_stopping_criteria(self): validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)