Removed max_length from being mandatory within generate. (#11314)
* Removed `max_length` from being mandatory within `generate`. - Moving on to fully using `StoppingCriteria` for `greedy` and `sample` modes. - `max_length` still used for `beam_search` and `group_beam_search` (Follow up PR) - Fixes a bug with MaxLengthStoppingCriteria (we should stop as soon a we hit the max_length, the comparison needs to be or equal, that affects the tests). - Added options to use `logits_processor` and `stopping_criteria` directly within `generate` function (so some users can define their own `logits_processor` and `stopping_criteria`). - Modified the backward compat tests to make sure we issue a warning. * Fix `max_length` argument in `generate`. * Moving validate to being functional. - Renamed `smax_length` to `stoppping_max_length`. * Removing `logits_processor` and `stopping_criteria` from `generate` arguments. * Deepcopy. * Fix global variable name.
This commit is contained in:
@@ -40,10 +40,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(11)
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
def test_max_length_criteria(self):
|
||||
@@ -52,10 +52,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(11)
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
def test_max_time_criteria(self):
|
||||
@@ -73,7 +73,6 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
with self.assertWarns(UserWarning):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
|
||||
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
validate_stopping_criteria(stopping_criteria, 11)
|
||||
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
|
||||
|
||||
self.assertEqual(len(stopping_criteria), 1)
|
||||
|
||||
Reference in New Issue
Block a user