From 0de15c988b0d27758ce360adb2627e9ea99e91b3 Mon Sep 17 00:00:00 2001 From: Sambhav Dixit <94298612+sambhavnoobcoder@users.noreply.github.com> Date: Wed, 5 Feb 2025 21:55:08 +0530 Subject: [PATCH] Fix Audio Classification Pipeline top_k Documentation Mismatch and Bug #35736 (#35771) * added condition for top_k Doc mismatch fix * initialization of test file for top_k changes * added test for returning all labels * added test for few labels * tests/test_audio_classification_top_k.py * final fix * ruff fix --------- Co-authored-by: sambhavnoobcoder --- .../pipelines/audio_classification.py | 15 +++-- tests/test_audio_classification_top_k.py | 60 +++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 tests/test_audio_classification_top_k.py diff --git a/src/transformers/pipelines/audio_classification.py b/src/transformers/pipelines/audio_classification.py index 4febb09e95..86dfe72d3c 100644 --- a/src/transformers/pipelines/audio_classification.py +++ b/src/transformers/pipelines/audio_classification.py @@ -91,8 +91,11 @@ class AudioClassificationPipeline(Pipeline): """ def __init__(self, *args, **kwargs): - # Default, might be overriden by the model.config. 
- kwargs["top_k"] = kwargs.get("top_k", 5) + # Only set default top_k if explicitly provided + if "top_k" in kwargs and kwargs["top_k"] is None: + kwargs["top_k"] = None + elif "top_k" not in kwargs: + kwargs["top_k"] = 5 super().__init__(*args, **kwargs) if self.framework != "pt": @@ -141,12 +144,16 @@ class AudioClassificationPipeline(Pipeline): return super().__call__(inputs, **kwargs) def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs): - # No parameters on this pipeline right now postprocess_params = {} - if top_k is not None: + + # If top_k is None, use all labels + if top_k is None: + postprocess_params["top_k"] = self.model.config.num_labels + else: if top_k > self.model.config.num_labels: top_k = self.model.config.num_labels postprocess_params["top_k"] = top_k + if function_to_apply is not None: if function_to_apply not in ["softmax", "sigmoid", "none"]: raise ValueError( diff --git a/tests/test_audio_classification_top_k.py b/tests/test_audio_classification_top_k.py new file mode 100644 index 0000000000..9911bd7323 --- /dev/null +++ b/tests/test_audio_classification_top_k.py @@ -0,0 +1,60 @@ +import unittest + +import numpy as np + +from transformers import pipeline +from transformers.testing_utils import require_torch + + +@require_torch +class AudioClassificationTopKTest(unittest.TestCase): + def test_top_k_none_returns_all_labels(self): + model_name = "superb/wav2vec2-base-superb-ks" # model with more than 5 labels + classification_pipeline = pipeline( + "audio-classification", + model=model_name, + top_k=None, + ) + + # Create dummy input + sampling_rate = 16000 + signal = np.zeros((sampling_rate,), dtype=np.float32) + + result = classification_pipeline(signal) + num_labels = classification_pipeline.model.config.num_labels + + self.assertEqual(len(result), num_labels, "Should return all labels when top_k is None") + + def test_top_k_none_with_few_labels(self): + model_name = "superb/hubert-base-superb-er" # model with fewer 
labels + classification_pipeline = pipeline( + "audio-classification", + model=model_name, + top_k=None, + ) + + # Create dummy input + sampling_rate = 16000 + signal = np.zeros((sampling_rate,), dtype=np.float32) + + result = classification_pipeline(signal) + num_labels = classification_pipeline.model.config.num_labels + + self.assertEqual(len(result), num_labels, "Should handle models with fewer labels correctly") + + def test_top_k_greater_than_labels(self): + model_name = "superb/hubert-base-superb-er" + classification_pipeline = pipeline( + "audio-classification", + model=model_name, + top_k=100, # intentionally large number + ) + + # Create dummy input + sampling_rate = 16000 + signal = np.zeros((sampling_rate,), dtype=np.float32) + + result = classification_pipeline(signal) + num_labels = classification_pipeline.model.config.num_labels + + self.assertEqual(len(result), num_labels, "Should cap top_k to number of labels")