[Whisper] Computing features on GPU in batch mode for whisper feature extractor. (#29900)
* add _torch_extract_fbank_features_batch function in feature_extractor_whisper
* reformat feature_extraction_whisper.py file
* handle batching in single function
* add gpu test & doc
* add batch test & device in each __call__
* add device arg in doc string

---------

Co-authored-by: vaibhav.aggarwal <vaibhav.aggarwal@sprinklr.com>
This commit is contained in:
@@ -24,7 +24,7 @@ import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import WhisperFeatureExtractor
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torch_gpu
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
@@ -207,6 +207,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch
|
||||
def test_torch_integration(self):
|
||||
# fmt: off
|
||||
@@ -223,6 +224,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
|
||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
@@ -253,3 +255,37 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
|
||||
self.assertTrue(np.all(np.mean(audio) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
|
||||
|
||||
@require_torch_gpu
@require_torch
def test_torch_integration_batch(self):
    # Batched integration check: run three samples through the extractor in a
    # single call and compare a small slice of the result with stored values.
    #
    # Reference slice, one row per sample: the first 30 values along the last
    # axis, taken at index 0 of the second axis of the extractor output.
    # fmt: off
    expected_slice = torch.tensor(
        [
            [
                0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951, 0.0971, -0.0817,
                -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678, 0.0709, -0.1867, -0.0655, -0.0274,
                -0.0234, -0.1884, -0.0516, -0.0554, -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
            ],
            [
                -0.4696, -0.0751, 0.0276, -0.0312, -0.0540, -0.0383, 0.1295, 0.0568, -0.2071, -0.0548,
                0.0389, -0.0316, -0.2346, -0.1068, -0.0322, 0.0475, -0.1709, -0.0041, 0.0872, 0.0537,
                0.0075, -0.0392, 0.0371, 0.0189, -0.1522, -0.0270, 0.0744, 0.0738, -0.0245, -0.0667
            ],
            [
                -0.2337, -0.0060, -0.0063, -0.2353, -0.0431, 0.1102, -0.1492, -0.0292, 0.0787, -0.0608,
                0.0143, 0.0582, 0.0072, 0.0101, -0.0444, -0.1701, -0.0064, -0.0027, -0.0826, -0.0730,
                -0.0099, -0.0762, -0.0170, 0.0446, -0.1153, 0.0960, -0.0361, 0.0652, 0.1207, 0.0277
            ]
        ]
    )
    # fmt: on

    speech_batch = self._load_datasamples(3)
    extractor = WhisperFeatureExtractor()
    # NOTE(review): the test is gated on a GPU, yet no `device` argument is
    # passed here, so where the computation runs depends on the extractor's
    # default -- confirm this is intended.
    encoded = extractor(speech_batch, return_tensors="pt")
    features = encoded.input_features

    # One entry per input sample is expected on the leading axis.
    self.assertEqual(features.shape, (3, 80, 3000))

    # Compare with an absolute tolerance; the reference values are stored to
    # four decimal places.
    slice_matches = torch.allclose(features[:, 0, :30], expected_slice, atol=1e-4)
    self.assertTrue(slice_matches)
|
||||
|
||||
Reference in New Issue
Block a user