[Wav2Vec2 Conformer] Fix inference float16 (#25985)

* [Wav2Vec2 Conformer] Fix inference float16

* fix test

* fix test more

* clean pipe test
This commit is contained in:
Sanchit Gandhi
2023-09-05 18:26:06 +01:00
committed by GitHub
parent 6bc517ccd4
commit 8d518013ef
3 changed files with 52 additions and 3 deletions

View File

@@ -901,6 +901,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@slow
@require_torch_gpu
def test_wav2vec2_conformer_float16(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-conformer-rope-large-960h-ft",
device="cuda:0",
torch_dtype=torch.float16,
framework="pt",
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
output = speech_recognizer(sample)
self.assertEqual(
output,
{"text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"},
)
@require_torch
def test_chunking_fast(self):
speech_recognizer = pipeline(