[Wav2Vec2 Conformer] Fix inference float16 (#25985)
* [Wav2Vec2 Conformer] Fix inference float16 * fix test * fix test more * clean pipe test
This commit is contained in:
@@ -901,6 +901,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_wav2vec2_conformer_float16(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="facebook/wav2vec2-conformer-rope-large-960h-ft",
|
||||
device="cuda:0",
|
||||
torch_dtype=torch.float16,
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
output = speech_recognizer(sample)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_chunking_fast(self):
|
||||
speech_recognizer = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user