Removed max_length from being mandatory within generate. (#11314)

* Removed `max_length` from being mandatory within `generate`.

- Moving on to fully using `StoppingCriteria` for `greedy` and `sample`
modes.
- `max_length` still used for `beam_search` and `group_beam_search`
(Follow up PR)
- Fixes a bug with MaxLengthStoppingCriteria (we should stop as soon a
we hit the max_length, the comparison needs to be or equal, that affects
the tests).
- Added options to use `logits_processor` and `stopping_criteria`
directly within `generate` function (so some users can define their own
`logits_processor` and `stopping_criteria`).
- Modified the backward compat tests to make sure we issue a warning.

* Fix `max_length` argument in `generate`.

* Moving validate to being functional.

- Renamed `smax_length` to `stoppping_max_length`.

* Removing `logits_processor` and `stopping_criteria` from `generate`
arguments.

* Deepcopy.

* Fix global variable name.
This commit is contained in:
Nicolas Patry
2021-04-21 11:56:45 +02:00
committed by GitHub
parent 95dab34d55
commit aad95c7cde
4 changed files with 122 additions and 80 deletions

View File

@@ -40,10 +40,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(11)
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_length_criteria(self):
@@ -52,10 +52,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(11)
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_time_criteria(self):
@@ -73,7 +73,6 @@ class StoppingCriteriaTestCase(unittest.TestCase):
with self.assertWarns(UserWarning):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
stopping_criteria = StoppingCriteriaList()
validate_stopping_criteria(stopping_criteria, 11)
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
self.assertEqual(len(stopping_criteria), 1)