Make audio classification pipeline spec-compliant and add test (#33730)
* Make audio classification pipeline spec-compliant and add test * Check that test actually running in CI * Try a different pipeline for the CI * Move the test so it gets triggered * Move it again, this time into task_tests! * make fixup * indentation fix * comment * Move everything from testing_utils to test_pipeline_mixin * Add output testing too * revert small diff with main * make fixup * Clarify comment * Update tests/pipelines/test_pipelines_audio_classification.py Co-authored-by: Lucain <lucainp@gmail.com> * Update tests/test_pipeline_mixin.py Co-authored-by: Lucain <lucainp@gmail.com> * Rename function and js_args -> hub_args * Cleanup the spec recursion * Check keys for all outputs --------- Co-authored-by: Lucain <lucainp@gmail.com>
This commit is contained in:
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from dataclasses import fields
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import AudioClassificationOutputElement
|
||||
|
||||
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
from transformers.pipelines import AudioClassificationPipeline, pipeline
|
||||
@@ -66,6 +68,11 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
||||
|
||||
self.run_torchaudio(audio_classifier)
|
||||
|
||||
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
|
||||
for single_output in output:
|
||||
output_keys = set(single_output.keys())
|
||||
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
|
||||
|
||||
@require_torchaudio
|
||||
def run_torchaudio(self, audio_classifier):
|
||||
import datasets
|
||||
|
||||
@@ -14,12 +14,20 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import unittest
|
||||
from dataclasses import fields, is_dataclass
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import get_args
|
||||
|
||||
from huggingface_hub import AudioClassificationInput
|
||||
|
||||
from transformers.pipelines import AudioClassificationPipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
require_decord,
|
||||
@@ -92,6 +100,12 @@ pipeline_test_mapping = {
|
||||
"zero-shot-object-detection": {"test": ZeroShotObjectDetectionPipelineTests},
|
||||
}
|
||||
|
||||
task_to_pipeline_and_spec_mapping = {
|
||||
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
|
||||
# task spec in the HF Hub
|
||||
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
|
||||
}
|
||||
|
||||
for task, task_info in pipeline_test_mapping.items():
|
||||
test = task_info["test"]
|
||||
task_info["mapping"] = {
|
||||
@@ -175,6 +189,9 @@ class PipelineTesterMixin:
|
||||
self.run_model_pipeline_tests(
|
||||
task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype
|
||||
)
|
||||
if task in task_to_pipeline_and_spec_mapping:
|
||||
pipeline, hub_spec = task_to_pipeline_and_spec_mapping[task]
|
||||
compare_pipeline_args_to_hub_spec(pipeline, hub_spec)
|
||||
|
||||
def run_model_pipeline_tests(
|
||||
self, task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype="float32"
|
||||
@@ -685,3 +702,87 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
|
||||
raise ValueError(
|
||||
"Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
|
||||
)
|
||||
|
||||
|
||||
def get_arg_names_from_hub_spec(hub_spec, first_level=True):
|
||||
# This util is used in pipeline tests, to verify that a pipeline's documented arguments
|
||||
# match the Hub specification for that task
|
||||
arg_names = []
|
||||
for field in fields(hub_spec):
|
||||
# Recurse into nested fields, but max one level
|
||||
if is_dataclass(field.type):
|
||||
arg_names.extend([field.name for field in fields(field.type)])
|
||||
continue
|
||||
# Next, catch nested fields that are part of a Union[], which is usually caused by Optional[]
|
||||
for param_type in get_args(field.type):
|
||||
if is_dataclass(param_type):
|
||||
# Again, recurse into nested fields, but max one level
|
||||
arg_names.extend([field.name for field in fields(param_type)])
|
||||
break
|
||||
else:
|
||||
# Finally, this line triggers if it's not a nested field
|
||||
arg_names.append(field.name)
|
||||
return arg_names
|
||||
|
||||
|
||||
def parse_args_from_docstring_by_indentation(docstring):
|
||||
# This util is used in pipeline tests, to extract the argument names from a google-format docstring
|
||||
# to compare them against the Hub specification for that task. It uses indentation levels as a primary
|
||||
# source of truth, so these have to be correct!
|
||||
docstring = dedent(docstring)
|
||||
lines_by_indent = [
|
||||
(len(line) - len(line.lstrip()), line.strip()) for line in docstring.split("\n") if line.strip()
|
||||
]
|
||||
args_lineno = None
|
||||
args_indent = None
|
||||
args_end = None
|
||||
for lineno, (indent, line) in enumerate(lines_by_indent):
|
||||
if line == "Args:":
|
||||
args_lineno = lineno
|
||||
args_indent = indent
|
||||
continue
|
||||
elif args_lineno is not None and indent == args_indent:
|
||||
args_end = lineno
|
||||
break
|
||||
if args_lineno is None:
|
||||
raise ValueError("No args block to parse!")
|
||||
elif args_end is None:
|
||||
args_block = lines_by_indent[args_lineno + 1 :]
|
||||
else:
|
||||
args_block = lines_by_indent[args_lineno + 1 : args_end]
|
||||
outer_indent_level = min(line[0] for line in args_block)
|
||||
outer_lines = [line for line in args_block if line[0] == outer_indent_level]
|
||||
arg_names = [re.match(r"(\w+)\W", line[1]).group(1) for line in outer_lines]
|
||||
return arg_names
|
||||
|
||||
|
||||
def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
|
||||
docstring = inspect.getdoc(pipeline_class.__call__).strip()
|
||||
docstring_args = set(parse_args_from_docstring_by_indentation(docstring))
|
||||
hub_args = set(get_arg_names_from_hub_spec(hub_spec))
|
||||
|
||||
# Special casing: We allow the name of this arg to differ
|
||||
js_generate_args = [js_arg for js_arg in hub_args if js_arg.startswith("generate")]
|
||||
docstring_generate_args = [
|
||||
docstring_arg for docstring_arg in docstring_args if docstring_arg.startswith("generate")
|
||||
]
|
||||
if (
|
||||
len(js_generate_args) == 1
|
||||
and len(docstring_generate_args) == 1
|
||||
and js_generate_args != docstring_generate_args
|
||||
):
|
||||
hub_args.remove(js_generate_args[0])
|
||||
docstring_args.remove(docstring_generate_args[0])
|
||||
|
||||
if hub_args != docstring_args:
|
||||
error = [f"{pipeline_class.__name__} differs from JS spec {hub_spec.__name__}"]
|
||||
matching_args = hub_args & docstring_args
|
||||
huggingface_hub_only = hub_args - docstring_args
|
||||
transformers_only = docstring_args - hub_args
|
||||
if matching_args:
|
||||
error.append(f"Matching args: {matching_args}")
|
||||
if huggingface_hub_only:
|
||||
error.append(f"Huggingface Hub only: {huggingface_hub_only}")
|
||||
if transformers_only:
|
||||
error.append(f"Transformers only: {transformers_only}")
|
||||
raise ValueError("\n".join(error))
|
||||
|
||||
Reference in New Issue
Block a user