[Whisper] Use torch for stft if available (#26119)
* [Whisper] Use torch for stft if available * update docstring * mock patch decorator * fit on one line
This commit is contained in:
@@ -23,16 +23,13 @@ import unittest
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
from transformers import WhisperFeatureExtractor
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import WhisperFeatureExtractor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
return values
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class WhisperFeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
|
||||
return speech_inputs
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None
|
||||
feature_extraction_class = WhisperFeatureExtractor
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = WhisperFeatureExtractionTester(self)
|
||||
@@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
@require_torch
|
||||
def test_double_precision_pad(self):
|
||||
import torch
|
||||
|
||||
@@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_integration(self):
|
||||
@require_torch
|
||||
def test_torch_integration(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_FEATURES = torch.tensor(
|
||||
[
|
||||
@@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
@unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False)
|
||||
def test_numpy_integration(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_FEATURES = np.array(
|
||||
[
|
||||
0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
|
||||
0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
|
||||
0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
|
||||
-0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="np").input_features
|
||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||
self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
audio = self._load_datasamples(1)[0]
|
||||
|
||||
Reference in New Issue
Block a user