Removed max_length from being mandatory within generate. (#11314)

* Removed `max_length` from being mandatory within `generate`.

- Moving on to fully using `StoppingCriteria` for `greedy` and `sample`
modes.
- `max_length` still used for `beam_search` and `group_beam_search`
(Follow up PR)
- Fixes a bug with MaxLengthStoppingCriteria (we should stop as soon a
we hit the max_length, the comparison needs to be or equal, that affects
the tests).
- Added options to use `logits_processor` and `stopping_criteria`
directly within `generate` function (so some users can define their own
`logits_processor` and `stopping_criteria`).
- Modified the backward compat tests to make sure we issue a warning.

* Fix `max_length` argument in `generate`.

* Moving validate to being functional.

- Renamed `smax_length` to `stoppping_max_length`.

* Removing `logits_processor` and `stopping_criteria` from `generate`
arguments.

* Deepcopy.

* Fix global variable name.
This commit is contained in:
Nicolas Patry
2021-04-21 11:56:45 +02:00
committed by GitHub
parent 95dab34d55
commit aad95c7cde
4 changed files with 122 additions and 80 deletions

View File

@@ -1358,13 +1358,14 @@ class GenerationIntegrationTests(unittest.TestCase):
bos_token_id=bart_model.config.bos_token_id,
)
bart_model.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
with self.assertWarns(UserWarning):
bart_model.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
def test_max_length_backward_compat_sample(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1381,13 +1382,14 @@ class GenerationIntegrationTests(unittest.TestCase):
bos_token_id=bart_model.config.bos_token_id,
)
with torch.no_grad():
bart_model.sample(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
with self.assertWarns(UserWarning):
bart_model.sample(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
def test_max_length_backward_compat_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1413,9 +1415,10 @@ class GenerationIntegrationTests(unittest.TestCase):
num_beams=num_beams,
device=torch_device,
)
_ = bart_model.beam_search(
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
)
with self.assertWarns(UserWarning):
_ = bart_model.beam_search(
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
)
def test_max_length_backward_compat_group_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1445,9 +1448,10 @@ class GenerationIntegrationTests(unittest.TestCase):
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
bart_model.group_beam_search(
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
)
with self.assertWarns(UserWarning):
bart_model.group_beam_search(
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
)
def test_max_length_warning_if_different(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""