Make audio classification pipeline spec-compliant and add test (#33730)
* Make audio classification pipeline spec-compliant and add test * Check that test actually running in CI * Try a different pipeline for the CI * Move the test so it gets triggered * Move it again, this time into task_tests! * make fixup * indentation fix * comment * Move everything from testing_utils to test_pipeline_mixin * Add output testing too * revert small diff with main * make fixup * Clarify comment * Update tests/pipelines/test_pipelines_audio_classification.py Co-authored-by: Lucain <lucainp@gmail.com> * Update tests/test_pipeline_mixin.py Co-authored-by: Lucain <lucainp@gmail.com> * Rename function and js_args -> hub_args * Cleanup the spec recursion * Check keys for all outputs --------- Co-authored-by: Lucain <lucainp@gmail.com>
This commit is contained in:
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from dataclasses import fields
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import AudioClassificationOutputElement
|
||||
|
||||
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
from transformers.pipelines import AudioClassificationPipeline, pipeline
|
||||
@@ -66,6 +68,11 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
||||
|
||||
self.run_torchaudio(audio_classifier)
|
||||
|
||||
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
|
||||
for single_output in output:
|
||||
output_keys = set(single_output.keys())
|
||||
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
|
||||
|
||||
@require_torchaudio
|
||||
def run_torchaudio(self, audio_classifier):
|
||||
import datasets
|
||||
|
||||
Reference in New Issue
Block a user