Add numpy alternative to FE using torchaudio (#26339)
* add audio_utils usage in the FE of SpeechToText * clean unecessary parameters of AudioSpectrogramTransformer FE * add audio_utils usage in AST * add serialization tests and function to FEs * make style * remove use_torchaudio and move to_dict to FE * test audio_utils usage * make style and fix import (remove torchaudio dependency import) * fix torch dependency for jax and tensor tests * fix typo * clean tests with suggestions * add lines to test if is_speech_availble is False
This commit is contained in:
@@ -15,13 +15,15 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import ASTFeatureExtractor
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
@@ -173,3 +175,48 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
|
||||
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
|
||||
self.assertEquals(input_values.shape, (1, 1024, 128))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
|
||||
|
||||
def test_feat_extract_from_and_save_pretrained(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertDictEqual(dict_first, dict_second)
|
||||
|
||||
def test_feat_extract_to_json_file(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
|
||||
feat_extract_first.to_json_file(json_file_path)
|
||||
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
|
||||
# exact same tests than before, except that we simulate that torchaudio is not available
|
||||
@require_torch
|
||||
@unittest.mock.patch(
|
||||
"transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available",
|
||||
lambda: False,
|
||||
)
|
||||
class ASTFeatureExtractionWithoutTorchaudioTest(ASTFeatureExtractionTest):
|
||||
def test_using_audio_utils(self):
|
||||
# Tests that it uses audio_utils instead of torchaudio
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
||||
self.assertTrue(hasattr(feat_extract, "window"))
|
||||
self.assertTrue(hasattr(feat_extract, "mel_filters"))
|
||||
|
||||
from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import (
|
||||
is_speech_available,
|
||||
)
|
||||
|
||||
self.assertFalse(is_speech_available())
|
||||
|
||||
@@ -15,20 +15,19 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers import Speech2TextFeatureExtractor
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import Speech2TextFeatureExtractor
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@@ -105,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None
|
||||
feature_extraction_class = Speech2TextFeatureExtractor
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
|
||||
@@ -280,3 +279,45 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertEquals(input_features.shape, (1, 584, 24))
|
||||
self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
|
||||
|
||||
def test_feat_extract_from_and_save_pretrained(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertDictEqual(dict_first, dict_second)
|
||||
|
||||
def test_feat_extract_to_json_file(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
|
||||
feat_extract_first.to_json_file(json_file_path)
|
||||
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
|
||||
# exact same tests than before, except that we simulate that torchaudio is not available
|
||||
@require_torch
|
||||
@unittest.mock.patch(
|
||||
"transformers.models.speech_to_text.feature_extraction_speech_to_text.is_speech_available", lambda: False
|
||||
)
|
||||
class Speech2TextFeatureExtractionWithoutTorchaudioTest(Speech2TextFeatureExtractionTest):
|
||||
def test_using_audio_utils(self):
|
||||
# Tests that it uses audio_utils instead of torchaudio
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
||||
self.assertTrue(hasattr(feat_extract, "window"))
|
||||
self.assertTrue(hasattr(feat_extract, "mel_filters"))
|
||||
|
||||
from transformers.models.speech_to_text.feature_extraction_speech_to_text import is_speech_available
|
||||
|
||||
self.assertFalse(is_speech_available())
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
from transformers import Speech2TextTokenizer, is_speech_available
|
||||
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer
|
||||
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio
|
||||
from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
@@ -26,10 +26,6 @@ from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
from .test_feature_extraction_speech_to_text import floats_list
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
|
||||
|
||||
|
||||
SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user