Make audio classification pipeline spec-compliant and add test (#33730)

* Make audio classification pipeline spec-compliant and add test

* Check that test actually running in CI

* Try a different pipeline for the CI

* Move the test so it gets triggered

* Move it again, this time into task_tests!

* make fixup

* indentation fix

* comment

* Move everything from testing_utils to test_pipeline_mixin

* Add output testing too

* revert small diff with main

* make fixup

* Clarify comment

* Update tests/pipelines/test_pipelines_audio_classification.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update tests/test_pipeline_mixin.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Rename function and js_args -> hub_args

* Cleanup the spec recursion

* Check keys for all outputs

---------

Co-authored-by: Lucain <lucainp@gmail.com>
This commit is contained in:
Matt
2024-09-27 17:01:06 +01:00
committed by GitHub
parent 4973fc5769
commit d3821c4aed
3 changed files with 130 additions and 3 deletions

View File

@@ -126,6 +126,11 @@ class AudioClassificationPipeline(Pipeline):
The number of top labels that will be returned by the pipeline. If the provided number is `None` or The number of top labels that will be returned by the pipeline. If the provided number is `None` or
higher than the number of labels available in the model configuration, it will default to the number of higher than the number of labels available in the model configuration, it will default to the number of
labels. labels.
function_to_apply(`str`, *optional*, defaults to "softmax"):
The function to apply to the model output. By default, the pipeline will apply the softmax function to
the output of the model. Valid options: ["softmax", "sigmoid", "none"]. Note that passing Python's
built-in `None` will default to "softmax", so you need to pass the string "none" to disable any
post-processing.
Return: Return:
A list of `dict` with the following keys: A list of `dict` with the following keys:
@@ -135,13 +140,22 @@ class AudioClassificationPipeline(Pipeline):
""" """
return super().__call__(inputs, **kwargs) return super().__call__(inputs, **kwargs)
def _sanitize_parameters(self, top_k=None, **kwargs): def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
# No parameters on this pipeline right now # No parameters on this pipeline right now
postprocess_params = {} postprocess_params = {}
if top_k is not None: if top_k is not None:
if top_k > self.model.config.num_labels: if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels top_k = self.model.config.num_labels
postprocess_params["top_k"] = top_k postprocess_params["top_k"] = top_k
if function_to_apply is not None:
if function_to_apply not in ["softmax", "sigmoid", "none"]:
raise ValueError(
f"Invalid value for `function_to_apply`: {function_to_apply}. "
"Valid options are ['softmax', 'sigmoid', 'none']"
)
postprocess_params["function_to_apply"] = function_to_apply
else:
postprocess_params["function_to_apply"] = "softmax"
return {}, {}, postprocess_params return {}, {}, postprocess_params
def preprocess(self, inputs): def preprocess(self, inputs):
@@ -203,8 +217,13 @@ class AudioClassificationPipeline(Pipeline):
model_outputs = self.model(**model_inputs) model_outputs = self.model(**model_inputs)
return model_outputs return model_outputs
def postprocess(self, model_outputs, top_k=5): def postprocess(self, model_outputs, top_k=5, function_to_apply="softmax"):
probs = model_outputs.logits[0].softmax(-1) if function_to_apply == "softmax":
probs = model_outputs.logits[0].softmax(-1)
elif function_to_apply == "sigmoid":
probs = model_outputs.logits[0].sigmoid()
else:
probs = model_outputs.logits[0]
scores, ids = probs.topk(top_k) scores, ids = probs.topk(top_k)
scores = scores.tolist() scores = scores.tolist()

View File

@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from dataclasses import fields
import numpy as np import numpy as np
from huggingface_hub import AudioClassificationOutputElement
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from transformers.pipelines import AudioClassificationPipeline, pipeline from transformers.pipelines import AudioClassificationPipeline, pipeline
@@ -66,6 +68,11 @@ class AudioClassificationPipelineTests(unittest.TestCase):
self.run_torchaudio(audio_classifier) self.run_torchaudio(audio_classifier)
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
for single_output in output:
output_keys = set(single_output.keys())
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
@require_torchaudio @require_torchaudio
def run_torchaudio(self, audio_classifier): def run_torchaudio(self, audio_classifier):
import datasets import datasets

View File

@@ -14,12 +14,20 @@
# limitations under the License. # limitations under the License.
import copy import copy
import inspect
import json import json
import os import os
import random import random
import re
import unittest import unittest
from dataclasses import fields, is_dataclass
from pathlib import Path from pathlib import Path
from textwrap import dedent
from typing import get_args
from huggingface_hub import AudioClassificationInput
from transformers.pipelines import AudioClassificationPipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
require_decord, require_decord,
@@ -92,6 +100,12 @@ pipeline_test_mapping = {
"zero-shot-object-detection": {"test": ZeroShotObjectDetectionPipelineTests}, "zero-shot-object-detection": {"test": ZeroShotObjectDetectionPipelineTests},
} }
task_to_pipeline_and_spec_mapping = {
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
# task spec in the HF Hub
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
}
for task, task_info in pipeline_test_mapping.items(): for task, task_info in pipeline_test_mapping.items():
test = task_info["test"] test = task_info["test"]
task_info["mapping"] = { task_info["mapping"] = {
@@ -175,6 +189,9 @@ class PipelineTesterMixin:
self.run_model_pipeline_tests( self.run_model_pipeline_tests(
task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype
) )
if task in task_to_pipeline_and_spec_mapping:
pipeline, hub_spec = task_to_pipeline_and_spec_mapping[task]
compare_pipeline_args_to_hub_spec(pipeline, hub_spec)
def run_model_pipeline_tests( def run_model_pipeline_tests(
self, task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype="float32" self, task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype="float32"
@@ -685,3 +702,87 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
raise ValueError( raise ValueError(
"Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`." "Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
) )
def get_arg_names_from_hub_spec(hub_spec, first_level=True):
# This util is used in pipeline tests, to verify that a pipeline's documented arguments
# match the Hub specification for that task
arg_names = []
for field in fields(hub_spec):
# Recurse into nested fields, but max one level
if is_dataclass(field.type):
arg_names.extend([field.name for field in fields(field.type)])
continue
# Next, catch nested fields that are part of a Union[], which is usually caused by Optional[]
for param_type in get_args(field.type):
if is_dataclass(param_type):
# Again, recurse into nested fields, but max one level
arg_names.extend([field.name for field in fields(param_type)])
break
else:
# Finally, this line triggers if it's not a nested field
arg_names.append(field.name)
return arg_names
def parse_args_from_docstring_by_indentation(docstring):
# This util is used in pipeline tests, to extract the argument names from a google-format docstring
# to compare them against the Hub specification for that task. It uses indentation levels as a primary
# source of truth, so these have to be correct!
docstring = dedent(docstring)
lines_by_indent = [
(len(line) - len(line.lstrip()), line.strip()) for line in docstring.split("\n") if line.strip()
]
args_lineno = None
args_indent = None
args_end = None
for lineno, (indent, line) in enumerate(lines_by_indent):
if line == "Args:":
args_lineno = lineno
args_indent = indent
continue
elif args_lineno is not None and indent == args_indent:
args_end = lineno
break
if args_lineno is None:
raise ValueError("No args block to parse!")
elif args_end is None:
args_block = lines_by_indent[args_lineno + 1 :]
else:
args_block = lines_by_indent[args_lineno + 1 : args_end]
outer_indent_level = min(line[0] for line in args_block)
outer_lines = [line for line in args_block if line[0] == outer_indent_level]
arg_names = [re.match(r"(\w+)\W", line[1]).group(1) for line in outer_lines]
return arg_names
def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
docstring = inspect.getdoc(pipeline_class.__call__).strip()
docstring_args = set(parse_args_from_docstring_by_indentation(docstring))
hub_args = set(get_arg_names_from_hub_spec(hub_spec))
# Special casing: We allow the name of this arg to differ
js_generate_args = [js_arg for js_arg in hub_args if js_arg.startswith("generate")]
docstring_generate_args = [
docstring_arg for docstring_arg in docstring_args if docstring_arg.startswith("generate")
]
if (
len(js_generate_args) == 1
and len(docstring_generate_args) == 1
and js_generate_args != docstring_generate_args
):
hub_args.remove(js_generate_args[0])
docstring_args.remove(docstring_generate_args[0])
if hub_args != docstring_args:
error = [f"{pipeline_class.__name__} differs from JS spec {hub_spec.__name__}"]
matching_args = hub_args & docstring_args
huggingface_hub_only = hub_args - docstring_args
transformers_only = docstring_args - hub_args
if matching_args:
error.append(f"Matching args: {matching_args}")
if huggingface_hub_only:
error.append(f"Huggingface Hub only: {huggingface_hub_only}")
if transformers_only:
error.append(f"Transformers only: {transformers_only}")
raise ValueError("\n".join(error))