[ASR pipeline] correct with lm pipeline (#15200)

* [ASR pipeline] correct with lm pipeline

* improve error
This commit is contained in:
Patrick von Platen
2022-01-18 15:36:22 +01:00
committed by GitHub
parent 1144d336b6
commit 497346d07e
6 changed files with 36 additions and 9 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import unittest
@@ -42,8 +43,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
# remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
config_dict.pop("feature_extractor_type")
config = Wav2Vec2FeatureExtractor(config_dict)
config = Wav2Vec2FeatureExtractor(**config_dict)
# save in new folder
model_config.save_pretrained(tmpdirname)
@@ -51,6 +53,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config = AutoFeatureExtractor.from_pretrained(tmpdirname)
# make sure private variable is not incorrectly saved
dict_as_saved = json.loads(config.to_json_string())
self.assertTrue("_processor_class" not in dict_as_saved)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_file(self):

View File

@@ -295,6 +295,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
@require_torch
@require_pyctcdecode
def test_with_lm_fast(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="hf-internal-testing/processor_with_lm",
framework="pt",
)
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]
n_repeats = 2
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_chunking(self):