Allow flexible generation params arg when checking pipeline specs (#37211)
* Allow flexible generation params arg * Trigger tests * Add docstring and rename js_generate to hub_generate
This commit is contained in:
@@ -930,6 +930,11 @@ def parse_args_from_docstring_by_indentation(docstring):
|
|||||||
|
|
||||||
|
|
||||||
def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
|
def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
|
||||||
|
"""
|
||||||
|
Compares the docstring of a pipeline class to the fields of the matching Hub input signature class to ensure that
|
||||||
|
they match. This guarantees that Transformers pipelines can be used in inference without needing to manually
|
||||||
|
refactor or rename inputs.
|
||||||
|
"""
|
||||||
ALLOWED_TRANSFORMERS_ONLY_ARGS = ["timeout"]
|
ALLOWED_TRANSFORMERS_ONLY_ARGS = ["timeout"]
|
||||||
|
|
||||||
docstring = inspect.getdoc(pipeline_class.__call__).strip()
|
docstring = inspect.getdoc(pipeline_class.__call__).strip()
|
||||||
@@ -937,16 +942,20 @@ def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
|
|||||||
hub_args = set(get_arg_names_from_hub_spec(hub_spec))
|
hub_args = set(get_arg_names_from_hub_spec(hub_spec))
|
||||||
|
|
||||||
# Special casing: We allow the name of this arg to differ
|
# Special casing: We allow the name of this arg to differ
|
||||||
js_generate_args = [js_arg for js_arg in hub_args if js_arg.startswith("generate")]
|
hub_generate_args = [
|
||||||
|
hub_arg for hub_arg in hub_args if hub_arg.startswith("generate") or hub_arg.startswith("generation")
|
||||||
|
]
|
||||||
docstring_generate_args = [
|
docstring_generate_args = [
|
||||||
docstring_arg for docstring_arg in docstring_args if docstring_arg.startswith("generate")
|
docstring_arg
|
||||||
|
for docstring_arg in docstring_args
|
||||||
|
if docstring_arg.startswith("generate") or docstring_arg.startswith("generation")
|
||||||
]
|
]
|
||||||
if (
|
if (
|
||||||
len(js_generate_args) == 1
|
len(hub_generate_args) == 1
|
||||||
and len(docstring_generate_args) == 1
|
and len(docstring_generate_args) == 1
|
||||||
and js_generate_args != docstring_generate_args
|
and hub_generate_args != docstring_generate_args
|
||||||
):
|
):
|
||||||
hub_args.remove(js_generate_args[0])
|
hub_args.remove(hub_generate_args[0])
|
||||||
docstring_args.remove(docstring_generate_args[0])
|
docstring_args.remove(docstring_generate_args[0])
|
||||||
|
|
||||||
# Special casing 2: We permit some transformers-only arguments that don't affect pipeline output
|
# Special casing 2: We permit some transformers-only arguments that don't affect pipeline output
|
||||||
|
|||||||
Reference in New Issue
Block a user