[Whisper] Use torch for stft if available (#26119)

* [Whisper] Use torch for stft if available

* update docstring

* mock patch decorator

* fit on one line
This commit is contained in:
Sanchit Gandhi
2023-12-21 11:04:05 +00:00
committed by GitHub
parent 7e93ce40c5
commit 814619f54f
2 changed files with 53 additions and 13 deletions

View File

@@ -23,16 +23,13 @@ import unittest
import numpy as np
from datasets import load_dataset
from transformers import is_speech_available
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
from transformers import WhisperFeatureExtractor
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
from transformers import WhisperFeatureExtractor
if is_torch_available():
import torch
@@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
return values
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
@@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
return speech_inputs
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None
feature_extraction_class = WhisperFeatureExtractor
def setUp(self):
self.feat_extract_tester = WhisperFeatureExtractionTester(self)
@@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
@require_torch
def test_double_precision_pad(self):
import torch
@@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
return [x["array"] for x in speech_samples]
def test_integration(self):
@require_torch
def test_torch_integration(self):
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
@@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
@unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False)
def test_numpy_integration(self):
# fmt: off
EXPECTED_INPUT_FEATURES = np.array(
[
0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
-0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
]
)
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="np").input_features
self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
audio = self._load_datasamples(1)[0]