Generate: GenerationConfig throws an exception when generate args are passed (#27757)
This commit is contained in:
@@ -497,6 +497,24 @@ class GenerationConfig(PushToHubMixin):
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 5. check common issue: passing `generate` arguments inside the generation config
|
||||
generate_arguments = (
|
||||
"logits_processor",
|
||||
"stopping_criteria",
|
||||
"prefix_allowed_tokens_fn",
|
||||
"synced_gpus",
|
||||
"assistant_model",
|
||||
"streamer",
|
||||
"negative_prompt_ids",
|
||||
"negative_prompt_attention_mask",
|
||||
)
|
||||
for arg in generate_arguments:
|
||||
if hasattr(self, arg):
|
||||
raise ValueError(
|
||||
f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
|
||||
"`generate()` (or a pipeline) directly."
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
|
||||
@@ -120,6 +120,34 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
self.assertEqual(loaded_config.do_sample, True)
|
||||
self.assertEqual(loaded_config.num_beams, 1) # default value
|
||||
|
||||
def test_validate(self):
|
||||
"""
|
||||
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
|
||||
"""
|
||||
# Case 1: A correct configuration will not throw any warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig()
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
||||
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# parameters with `do_sample=False`). May be escalated to an error in the future.
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(temperature=0.5)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
|
||||
# Case 3: Impossible sets of contraints/parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(num_return_sequences=2)
|
||||
|
||||
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(logits_processor="foo")
|
||||
|
||||
# Case 5: Model-specific parameters will NOT raise an exception or a warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(foo="bar")
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
||||
def test_refuse_to_save(self):
|
||||
"""Tests that we refuse to save a generation config that fails validation."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user