Track each row separately for stopping criteria (#29116)
This commit is contained in:
committed by
GitHub
parent
ece1b62b93
commit
8f2f0f0f85
@@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
|||||||
Additional stopping criteria specific kwargs.
|
Additional stopping criteria specific kwargs.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
`bool`. `False` indicates we should continue, `True` indicates we should stop.
|
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size,)`), where `True` indicates we stop generation
|
||||||
|
for a particular row, `False` indicates we should continue.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
||||||
|
|
||||||
|
|
||||||
@@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
|
|||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
is_done = cur_len >= self.max_length
|
is_done = cur_len >= self.max_length
|
||||||
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
|
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
|
||||||
@@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
|
|||||||
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
|
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
|
||||||
"exceptions, performance degradation, or nothing at all."
|
"exceptions, performance degradation, or nothing at all."
|
||||||
)
|
)
|
||||||
return is_done
|
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
|
||||||
|
|
||||||
|
|
||||||
class MaxNewTokensCriteria(StoppingCriteria):
|
class MaxNewTokensCriteria(StoppingCriteria):
|
||||||
@@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
|
|||||||
self.max_length = start_length + max_new_tokens
|
self.max_length = start_length + max_new_tokens
|
||||||
|
|
||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
return input_ids.shape[-1] >= self.max_length
|
is_done = input_ids.shape[-1] >= self.max_length
|
||||||
|
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
|
||||||
|
|
||||||
|
|
||||||
class MaxTimeCriteria(StoppingCriteria):
|
class MaxTimeCriteria(StoppingCriteria):
|
||||||
@@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
|
|||||||
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
|
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
|
||||||
|
|
||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
return time.time() - self.initial_timestamp > self.max_time
|
is_done = time.time() - self.initial_timestamp > self.max_time
|
||||||
|
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
|
||||||
|
|
||||||
|
|
||||||
class StoppingCriteriaList(list):
|
class StoppingCriteriaList(list):
|
||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
return any(criteria(input_ids, scores, **kwargs) for criteria in self)
|
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
|
||||||
|
for criteria in self:
|
||||||
|
is_done = is_done | criteria(input_ids, scores, **kwargs)
|
||||||
|
return is_done
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_length(self) -> Optional[int]:
|
def max_length(self) -> Optional[int]:
|
||||||
|
|||||||
@@ -2195,11 +2195,9 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# stop when each sentence is finished
|
# stop when each sentence is finished
|
||||||
if unfinished_sequences.max() == 0:
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
this_peer_finished = True
|
|
||||||
|
|
||||||
# stop if we exceed the maximum length
|
if unfinished_sequences.max() == 0:
|
||||||
if stopping_criteria(input_ids, scores):
|
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
if this_peer_finished and not synced_gpus:
|
if this_peer_finished and not synced_gpus:
|
||||||
@@ -2478,14 +2476,12 @@ class GenerationMixin:
|
|||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
|
|
||||||
# stop when each sentence is finished
|
# stop when each sentence is finished
|
||||||
if unfinished_sequences.max() == 0:
|
if unfinished_sequences.max() == 0:
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
# stop if we exceed the maximum length
|
|
||||||
if stopping_criteria(input_ids, scores):
|
|
||||||
this_peer_finished = True
|
|
||||||
|
|
||||||
if this_peer_finished and not synced_gpus:
|
if this_peer_finished and not synced_gpus:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -2772,14 +2768,12 @@ class GenerationMixin:
|
|||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
|
|
||||||
# stop when each sentence is finished
|
# stop when each sentence is finished
|
||||||
if unfinished_sequences.max() == 0:
|
if unfinished_sequences.max() == 0:
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
# stop if we exceed the maximum length
|
|
||||||
if stopping_criteria(input_ids, scores):
|
|
||||||
this_peer_finished = True
|
|
||||||
|
|
||||||
if this_peer_finished and not synced_gpus:
|
if this_peer_finished and not synced_gpus:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -3169,7 +3163,7 @@ class GenerationMixin:
|
|||||||
# increase cur_len
|
# increase cur_len
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||||
if not synced_gpus:
|
if not synced_gpus:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -3516,7 +3510,7 @@ class GenerationMixin:
|
|||||||
# increase cur_len
|
# increase cur_len
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||||
if not synced_gpus:
|
if not synced_gpus:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -3912,7 +3906,7 @@ class GenerationMixin:
|
|||||||
# increase cur_len
|
# increase cur_len
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||||
if not synced_gpus:
|
if not synced_gpus:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -4267,7 +4261,7 @@ class GenerationMixin:
|
|||||||
# increase cur_len
|
# increase cur_len
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||||
if not synced_gpus:
|
if not synced_gpus:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -4657,14 +4651,12 @@ class GenerationMixin:
|
|||||||
.prod(dim=0)
|
.prod(dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
|
|
||||||
# stop when each sentence is finished
|
# stop when each sentence is finished
|
||||||
if unfinished_sequences.max() == 0:
|
if unfinished_sequences.max() == 0:
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
# stop if we exceed the maximum length
|
|
||||||
if stopping_criteria(input_ids, scores):
|
|
||||||
this_peer_finished = True
|
|
||||||
|
|
||||||
if this_peer_finished and not synced_gpus:
|
if this_peer_finished and not synced_gpus:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(9)
|
input_ids, scores = self._get_tensors(9)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(10)
|
input_ids, scores = self._get_tensors(10)
|
||||||
self.assertTrue(criteria(input_ids, scores))
|
self.assertTrue(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
def test_max_length_criteria(self):
|
def test_max_length_criteria(self):
|
||||||
criteria = MaxLengthCriteria(max_length=10)
|
criteria = MaxLengthCriteria(max_length=10)
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(5)
|
input_ids, scores = self._get_tensors(5)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(9)
|
input_ids, scores = self._get_tensors(9)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(10)
|
input_ids, scores = self._get_tensors(10)
|
||||||
self.assertTrue(criteria(input_ids, scores))
|
self.assertTrue(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
def test_max_new_tokens_criteria(self):
|
def test_max_new_tokens_criteria(self):
|
||||||
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
|
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(5)
|
input_ids, scores = self._get_tensors(5)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(9)
|
input_ids, scores = self._get_tensors(9)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
input_ids, scores = self._get_tensors(10)
|
input_ids, scores = self._get_tensors(10)
|
||||||
self.assertTrue(criteria(input_ids, scores))
|
self.assertTrue(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
criteria_list = StoppingCriteriaList([criteria])
|
criteria_list = StoppingCriteriaList([criteria])
|
||||||
self.assertEqual(criteria_list.max_length, 10)
|
self.assertEqual(criteria_list.max_length, 10)
|
||||||
@@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
input_ids, scores = self._get_tensors(5)
|
input_ids, scores = self._get_tensors(5)
|
||||||
|
|
||||||
criteria = MaxTimeCriteria(max_time=0.1)
|
criteria = MaxTimeCriteria(max_time=0.1)
|
||||||
self.assertFalse(criteria(input_ids, scores))
|
self.assertFalse(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
||||||
self.assertTrue(criteria(input_ids, scores))
|
self.assertTrue(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
def test_validate_stopping_criteria(self):
|
def test_validate_stopping_criteria(self):
|
||||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
||||||
|
|||||||
Reference in New Issue
Block a user