Removed max_length from being mandatory within generate. (#11314)
* Removed `max_length` from being mandatory within `generate`. - Moving on to fully using `StoppingCriteria` for `greedy` and `sample` modes. - `max_length` still used for `beam_search` and `group_beam_search` (Follow up PR) - Fixes a bug with MaxLengthStoppingCriteria (we should stop as soon a we hit the max_length, the comparison needs to be or equal, that affects the tests). - Added options to use `logits_processor` and `stopping_criteria` directly within `generate` function (so some users can define their own `logits_processor` and `stopping_criteria`). - Modified the backward compat tests to make sure we issue a warning. * Fix `max_length` argument in `generate`. * Moving validate to being functional. - Renamed `smax_length` to `stoppping_max_length`. * Removing `logits_processor` and `stopping_criteria` from `generate` arguments. * Deepcopy. * Fix global variable name.
This commit is contained in:
@@ -1358,13 +1358,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
bart_model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_sample(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -1381,13 +1382,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
with torch.no_grad():
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -1413,9 +1415,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
_ = bart_model.beam_search(
|
||||
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
_ = bart_model.beam_search(
|
||||
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_group_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -1445,9 +1448,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=num_beam_groups,
|
||||
)
|
||||
bart_model.group_beam_search(
|
||||
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.group_beam_search(
|
||||
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
|
||||
)
|
||||
|
||||
def test_max_length_warning_if_different(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
|
||||
Reference in New Issue
Block a user