[ASR pipeline] correct with lm pipeline (#15200)
* [ASR pipeline] correct with lm pipeline
* improve error
This commit is contained in:
committed by
GitHub
parent
1144d336b6
commit
497346d07e
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -42,8 +43,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
|
||||
# remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
|
||||
config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
|
||||
|
||||
config_dict.pop("feature_extractor_type")
|
||||
config = Wav2Vec2FeatureExtractor(config_dict)
|
||||
config = Wav2Vec2FeatureExtractor(**config_dict)
|
||||
|
||||
# save in new folder
|
||||
model_config.save_pretrained(tmpdirname)
|
||||
@@ -51,6 +53,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
|
||||
config = AutoFeatureExtractor.from_pretrained(tmpdirname)
|
||||
|
||||
# make sure private variable is not incorrectly saved
|
||||
dict_as_saved = json.loads(config.to_json_string())
|
||||
self.assertTrue("_processor_class" not in dict_as_saved)
|
||||
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_file(self):
|
||||
|
||||
@@ -295,6 +295,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
||||
|
||||
@require_torch
@require_pyctcdecode
def test_with_lm_fast(self):
    """Run the ASR pipeline built from the `processor_with_lm` checkpoint on a tiled clip.

    Checks that the pipeline reports the ``"ctc_with_lm"`` type, that a single
    input yields a single ``{"text": str}`` result, and that the transcription
    starts with the expected six characters.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model="hf-internal-testing/processor_with_lm",
        framework="pt",
    )
    self.assertEqual(asr.type, "ctc_with_lm")

    # Sort by id so that index 40 always refers to the same sample.
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    dataset = dataset.sort("id")
    waveform = dataset[40]["audio"]["array"]

    # Feed the same clip repeated back-to-back as one longer input.
    n_repeats = 2
    repeated_waveform = np.tile(waveform, n_repeats)
    output = asr([repeated_waveform], batch_size=2)

    self.assertEqual(output, [{"text": ANY(str)}])
    self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_chunking(self):
|
||||
|
||||
Reference in New Issue
Block a user