Allow flexible generation params arg when checking pipeline specs (#37211)

* Allow flexible generation params arg

* Trigger tests

* Add docstring and rename js_generate to hub_generate
This commit is contained in:
Matt
2025-04-03 13:29:36 +01:00
committed by GitHub
parent afafb84b59
commit 782d7d945d

View File

@@ -930,6 +930,11 @@ def parse_args_from_docstring_by_indentation(docstring):
def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
"""
Compares the docstring of a pipeline class to the fields of the matching Hub input signature class to ensure that
they match. This guarantees that Transformers pipelines can be used in inference without needing to manually
refactor or rename inputs.
"""
ALLOWED_TRANSFORMERS_ONLY_ARGS = ["timeout"]
docstring = inspect.getdoc(pipeline_class.__call__).strip()
@@ -937,16 +942,20 @@ def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
hub_args = set(get_arg_names_from_hub_spec(hub_spec))
# Special casing: We allow the name of this arg to differ
js_generate_args = [js_arg for js_arg in hub_args if js_arg.startswith("generate")]
hub_generate_args = [
hub_arg for hub_arg in hub_args if hub_arg.startswith("generate") or hub_arg.startswith("generation")
]
docstring_generate_args = [
docstring_arg for docstring_arg in docstring_args if docstring_arg.startswith("generate")
docstring_arg
for docstring_arg in docstring_args
if docstring_arg.startswith("generate") or docstring_arg.startswith("generation")
]
if (
len(js_generate_args) == 1
len(hub_generate_args) == 1
and len(docstring_generate_args) == 1
and js_generate_args != docstring_generate_args
and hub_generate_args != docstring_generate_args
):
hub_args.remove(js_generate_args[0])
hub_args.remove(hub_generate_args[0])
docstring_args.remove(docstring_generate_args[0])
# Special casing 2: We permit some transformers-only arguments that don't affect pipeline output