Make audio classification pipeline spec-compliant and add test (#33730)

* Make audio classification pipeline spec-compliant and add test

* Check that test actually running in CI

* Try a different pipeline for the CI

* Move the test so it gets triggered

* Move it again, this time into task_tests!

* make fixup

* indentation fix

* comment

* Move everything from testing_utils to test_pipeline_mixin

* Add output testing too

* revert small diff with main

* make fixup

* Clarify comment

* Update tests/pipelines/test_pipelines_audio_classification.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update tests/test_pipeline_mixin.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Rename function and js_args -> hub_args

* Cleanup the spec recursion

* Check keys for all outputs

---------

Co-authored-by: Lucain <lucainp@gmail.com>
This commit is contained in:
Matt
2024-09-27 17:01:06 +01:00
committed by GitHub
parent 4973fc5769
commit d3821c4aed
3 changed files with 130 additions and 3 deletions

View File

@@ -126,6 +126,11 @@ class AudioClassificationPipeline(Pipeline):
The number of top labels that will be returned by the pipeline. If the provided number is `None` or
higher than the number of labels available in the model configuration, it will default to the number of
labels.
function_to_apply(`str`, *optional*, defaults to "softmax"):
The function to apply to the model output. By default, the pipeline will apply the softmax function to
the output of the model. Valid options: ["softmax", "sigmoid", "none"]. Note that passing Python's
built-in `None` will default to "softmax", so you need to pass the string "none" to disable any
post-processing.
Return:
A list of `dict` with the following keys:
@@ -135,13 +140,22 @@ class AudioClassificationPipeline(Pipeline):
"""
return super().__call__(inputs, **kwargs)
def _sanitize_parameters(self, top_k=None, **kwargs):
def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
# No parameters on this pipeline right now
postprocess_params = {}
if top_k is not None:
if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
postprocess_params["top_k"] = top_k
if function_to_apply is not None:
if function_to_apply not in ["softmax", "sigmoid", "none"]:
raise ValueError(
f"Invalid value for `function_to_apply`: {function_to_apply}. "
"Valid options are ['softmax', 'sigmoid', 'none']"
)
postprocess_params["function_to_apply"] = function_to_apply
else:
postprocess_params["function_to_apply"] = "softmax"
return {}, {}, postprocess_params
def preprocess(self, inputs):
@@ -203,8 +217,13 @@ class AudioClassificationPipeline(Pipeline):
model_outputs = self.model(**model_inputs)
return model_outputs
def postprocess(self, model_outputs, top_k=5):
def postprocess(self, model_outputs, top_k=5, function_to_apply="softmax"):
if function_to_apply == "softmax":
probs = model_outputs.logits[0].softmax(-1)
elif function_to_apply == "sigmoid":
probs = model_outputs.logits[0].sigmoid()
else:
probs = model_outputs.logits[0]
scores, ids = probs.topk(top_k)
scores = scores.tolist()

View File

@@ -13,8 +13,10 @@
# limitations under the License.
import unittest
from dataclasses import fields
import numpy as np
from huggingface_hub import AudioClassificationOutputElement
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from transformers.pipelines import AudioClassificationPipeline, pipeline
@@ -66,6 +68,11 @@ class AudioClassificationPipelineTests(unittest.TestCase):
self.run_torchaudio(audio_classifier)
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
for single_output in output:
output_keys = set(single_output.keys())
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
@require_torchaudio
def run_torchaudio(self, audio_classifier):
import datasets

View File

@@ -14,12 +14,20 @@
# limitations under the License.
import copy
import inspect
import json
import os
import random
import re
import unittest
from dataclasses import fields, is_dataclass
from pathlib import Path
from textwrap import dedent
from typing import get_args
from huggingface_hub import AudioClassificationInput
from transformers.pipelines import AudioClassificationPipeline
from transformers.testing_utils import (
is_pipeline_test,
require_decord,
@@ -92,6 +100,12 @@ pipeline_test_mapping = {
"zero-shot-object-detection": {"test": ZeroShotObjectDetectionPipelineTests},
}
task_to_pipeline_and_spec_mapping = {
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
# task spec in the HF Hub
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
}
for task, task_info in pipeline_test_mapping.items():
test = task_info["test"]
task_info["mapping"] = {
@@ -175,6 +189,9 @@ class PipelineTesterMixin:
self.run_model_pipeline_tests(
task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype
)
if task in task_to_pipeline_and_spec_mapping:
pipeline, hub_spec = task_to_pipeline_and_spec_mapping[task]
compare_pipeline_args_to_hub_spec(pipeline, hub_spec)
def run_model_pipeline_tests(
self, task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype="float32"
@@ -685,3 +702,87 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
raise ValueError(
"Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
)
def get_arg_names_from_hub_spec(hub_spec, first_level=True):
# This util is used in pipeline tests, to verify that a pipeline's documented arguments
# match the Hub specification for that task
arg_names = []
for field in fields(hub_spec):
# Recurse into nested fields, but max one level
if is_dataclass(field.type):
arg_names.extend([field.name for field in fields(field.type)])
continue
# Next, catch nested fields that are part of a Union[], which is usually caused by Optional[]
for param_type in get_args(field.type):
if is_dataclass(param_type):
# Again, recurse into nested fields, but max one level
arg_names.extend([field.name for field in fields(param_type)])
break
else:
# Finally, this line triggers if it's not a nested field
arg_names.append(field.name)
return arg_names
def parse_args_from_docstring_by_indentation(docstring):
# This util is used in pipeline tests, to extract the argument names from a google-format docstring
# to compare them against the Hub specification for that task. It uses indentation levels as a primary
# source of truth, so these have to be correct!
docstring = dedent(docstring)
lines_by_indent = [
(len(line) - len(line.lstrip()), line.strip()) for line in docstring.split("\n") if line.strip()
]
args_lineno = None
args_indent = None
args_end = None
for lineno, (indent, line) in enumerate(lines_by_indent):
if line == "Args:":
args_lineno = lineno
args_indent = indent
continue
elif args_lineno is not None and indent == args_indent:
args_end = lineno
break
if args_lineno is None:
raise ValueError("No args block to parse!")
elif args_end is None:
args_block = lines_by_indent[args_lineno + 1 :]
else:
args_block = lines_by_indent[args_lineno + 1 : args_end]
outer_indent_level = min(line[0] for line in args_block)
outer_lines = [line for line in args_block if line[0] == outer_indent_level]
arg_names = [re.match(r"(\w+)\W", line[1]).group(1) for line in outer_lines]
return arg_names
def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
docstring = inspect.getdoc(pipeline_class.__call__).strip()
docstring_args = set(parse_args_from_docstring_by_indentation(docstring))
hub_args = set(get_arg_names_from_hub_spec(hub_spec))
# Special casing: We allow the name of this arg to differ
js_generate_args = [js_arg for js_arg in hub_args if js_arg.startswith("generate")]
docstring_generate_args = [
docstring_arg for docstring_arg in docstring_args if docstring_arg.startswith("generate")
]
if (
len(js_generate_args) == 1
and len(docstring_generate_args) == 1
and js_generate_args != docstring_generate_args
):
hub_args.remove(js_generate_args[0])
docstring_args.remove(docstring_generate_args[0])
if hub_args != docstring_args:
error = [f"{pipeline_class.__name__} differs from JS spec {hub_spec.__name__}"]
matching_args = hub_args & docstring_args
huggingface_hub_only = hub_args - docstring_args
transformers_only = docstring_args - hub_args
if matching_args:
error.append(f"Matching args: {matching_args}")
if huggingface_hub_only:
error.append(f"Huggingface Hub only: {huggingface_hub_only}")
if transformers_only:
error.append(f"Transformers only: {transformers_only}")
raise ValueError("\n".join(error))