From 871ba71dfa04f9d37a4f32e1f962a1199a5cf51a Mon Sep 17 00:00:00 2001 From: FredericOdermatt <50372080+FredericOdermatt@users.noreply.github.com> Date: Tue, 27 Feb 2024 09:43:52 +0900 Subject: [PATCH] GenerationConfig validate both constraints and force_words_ids (#29163) GenerationConfig validate both options for constrained decoding: constraints and force_words_ids --- src/transformers/generation/configuration_utils.py | 8 ++++---- tests/generation/test_configuration_utils.py | 5 +++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 87335b2667..f6d9c8f52c 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -482,11 +482,11 @@ class GenerationConfig(PushToHubMixin): # 3. detect incorrect paramaterization specific to advanced beam modes else: # constrained beam search - if self.constraints is not None: + if self.constraints is not None or self.force_words_ids is not None: constrained_wrong_parameter_msg = ( - "`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set " - "to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or " - "unset `{flag_name}` to continue." + fix_location + "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, " + "`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set " + "`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location ) if self.do_sample is True: raise ValueError( diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 4ff9d35aa0..a86dd31440 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -156,6 +156,11 @@ class GenerationConfigTest(unittest.TestCase): # Impossible sets of contraints/parameters will raise an exception with self.assertRaises(ValueError): GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2) + with self.assertRaises(ValueError): + # dummy constraint + GenerationConfig(do_sample=True, num_beams=2, constraints=["dummy"]) + with self.assertRaises(ValueError): + GenerationConfig(do_sample=True, num_beams=2, force_words_ids=[[[1, 2, 3]]]) # Passing `generate()`-only flags to `validate` will raise an exception with self.assertRaises(ValueError):