Add numpy alternative to FE using torchaudio (#26339)
* add audio_utils usage in the FE of SpeechToText * clean unnecessary parameters of AudioSpectrogramTransformer FE * add audio_utils usage in AST * add serialization tests and function to FEs * make style * remove use_torchaudio and move to_dict to FE * test audio_utils usage * make style and fix import (remove torchaudio dependency import) * fix torch dependency for jax and tensor tests * fix typo * clean tests with suggestions * add lines to test if is_speech_available is False
This commit is contained in:
@@ -15,13 +15,15 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import ASTFeatureExtractor
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
@@ -173,3 +175,48 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
|
||||
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
|
||||
self.assertEquals(input_values.shape, (1, 1024, 128))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
|
||||
|
||||
def test_feat_extract_from_and_save_pretrained(self):
    """Round-trip a feature extractor through `save_pretrained` / `from_pretrained`.

    Builds an extractor from `self.feat_extract_dict`, saves it into a
    temporary directory, checks the format of the first saved file, reloads
    the extractor from that directory and asserts that both instances
    serialize to the same dictionary.
    """
    original = self.feature_extraction_class(**self.feat_extract_dict)

    with tempfile.TemporaryDirectory() as save_dir:
        saved_files = original.save_pretrained(save_dir)
        # The first returned entry is the file whose JSON format is verified.
        check_json_file_has_correct_format(saved_files[0])
        reloaded = self.feature_extraction_class.from_pretrained(save_dir)

    self.assertDictEqual(original.to_dict(), reloaded.to_dict())
|
||||
|
||||
def test_feat_extract_to_json_file(self):
    """Round-trip a feature extractor through `to_json_file` / `from_json_file`.

    Builds an extractor from `self.feat_extract_dict`, writes it to
    `feat_extract.json` inside a temporary directory, reads it back and
    asserts that both instances serialize to the same dictionary.
    """
    original = self.feature_extraction_class(**self.feat_extract_dict)

    with tempfile.TemporaryDirectory() as work_dir:
        target_path = os.path.join(work_dir, "feat_extract.json")
        original.to_json_file(target_path)
        reloaded = self.feature_extraction_class.from_json_file(target_path)

    self.assertEqual(original.to_dict(), reloaded.to_dict())
|
||||
|
||||
|
||||
# exact same tests as before, except that we simulate that torchaudio is not available
|
||||
@require_torch
@unittest.mock.patch(
    "transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available",
    lambda: False,
)
class ASTFeatureExtractionWithoutTorchaudioTest(ASTFeatureExtractionTest):
    """Re-run the inherited AST feature-extraction tests with torchaudio hidden.

    The class-level patch replaces `is_speech_available` in the feature
    extraction module with a callable that always returns `False`.
    """

    def test_using_audio_utils(self):
        """Check that the extractor takes the `audio_utils` path, not torchaudio."""
        extractor_kwargs = self.feat_extract_tester.prepare_feat_extract_dict()
        extractor = self.feature_extraction_class(**extractor_kwargs)

        # These attributes are expected on the extractor when torchaudio is not used.
        for attribute in ("window", "mel_filters"):
            self.assertTrue(hasattr(extractor, attribute))

        # Imported here so the name resolves while the patch is active.
        from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import (
            is_speech_available,
        )

        self.assertFalse(is_speech_available())
|
||||
|
||||
Reference in New Issue
Block a user