Adding support for fp16 for asr pipeline. (#20864)

* Supporting `fp16` for asr pipeline

* Adding test.

* Style.

* Oops.

* Flake8 update ?

* Fixing flake8 ?

* Revert "Flake8 update ?"

This reverts commit 0b917fcb520e5f34d1933d9d37d8f32b64553048.

* Style (accidentally deleted flake8 F401.)

* Move to a bigger test (no small whisper model, and s2t doesn't seem to
accept torch_dtype=fp16).

Also we need to use a GPU to actually compute on fp16.

* Using BatchFeature capability.
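
A minimal sketch of the idea behind the fp16 input handling, not the code from this commit: when the model runs in `float16`, the floating-point input features have to be cast to the same dtype, while integer arrays such as masks should be left alone. The helper name `cast_floating_features` and the plain-dict feature container are hypothetical stand-ins used for illustration; the sketch uses NumPy only, so it runs without a GPU.

```python
import numpy as np


def cast_floating_features(features, dtype):
    """Hypothetical helper: cast only floating-point arrays to `dtype`.

    Integer arrays (for example an attention mask) are returned unchanged.
    """
    casted = {}
    for name, value in features.items():
        if np.issubdtype(value.dtype, np.floating):
            casted[name] = value.astype(dtype)
        else:
            casted[name] = value
    return casted


# Same synthetic waveform as in the test: 34 repetitions of 0..999.
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
features = {
    "input_features": waveform,
    "attention_mask": np.ones(waveform.shape, dtype=np.int64),
}

half = cast_floating_features(features, np.float16)
print(half["input_features"].dtype, half["attention_mask"].dtype)
```

Casting only the floating-point entries keeps index-like inputs exact while matching the model's compute dtype.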
Nicolas Patry
2022-12-23 10:18:45 +01:00
committed by GitHub
parent 15bc776fec
commit f7f0ec2f54
3 changed files with 26 additions and 3 deletions


@@ -145,6 +145,19 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
        with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
            _ = speech_recognizer(waveform, return_timestamps="char")

    @slow
    @require_torch
    def test_whisper_fp16(self):
        if not torch.cuda.is_available():
            self.skipTest("Cuda is necessary for this test")
        speech_recognizer = pipeline(
            model="openai/whisper-base",
            device=0,
            torch_dtype=torch.float16,
        )
        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
        speech_recognizer(waveform)

    @require_torch
    def test_small_model_pt_seq2seq(self):
        speech_recognizer = pipeline(