diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py index 94bc3d5fae..32969ab0d3 100644 --- a/tests/test_pipeline_mixin.py +++ b/tests/test_pipeline_mixin.py @@ -930,6 +930,11 @@ def parse_args_from_docstring_by_indentation(docstring): def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec): + """ + Compares the docstring of a pipeline class to the fields of the matching Hub input signature class to ensure that + they match. This guarantees that Transformers pipelines can be used in inference without needing to manually + refactor or rename inputs. + """ ALLOWED_TRANSFORMERS_ONLY_ARGS = ["timeout"] docstring = inspect.getdoc(pipeline_class.__call__).strip() @@ -937,16 +942,20 @@ def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec): hub_args = set(get_arg_names_from_hub_spec(hub_spec)) # Special casing: We allow the name of this arg to differ - js_generate_args = [js_arg for js_arg in hub_args if js_arg.startswith("generate")] + hub_generate_args = [ + hub_arg for hub_arg in hub_args if hub_arg.startswith("generate") or hub_arg.startswith("generation") + ] docstring_generate_args = [ - docstring_arg for docstring_arg in docstring_args if docstring_arg.startswith("generate") + docstring_arg + for docstring_arg in docstring_args + if docstring_arg.startswith("generate") or docstring_arg.startswith("generation") ] if ( - len(js_generate_args) == 1 + len(hub_generate_args) == 1 and len(docstring_generate_args) == 1 - and js_generate_args != docstring_generate_args + and hub_generate_args != docstring_generate_args ): - hub_args.remove(js_generate_args[0]) + hub_args.remove(hub_generate_args[0]) docstring_args.remove(docstring_generate_args[0]) # Special casing 2: We permit some transformers-only arguments that don't affect pipeline output