Adding support for fp16 for asr pipeline. (#20864)
* Supporting `fp16` for asr pipeline * Adding test. * Style. * Oops. * Flake8 update ? * Fixing flake8 ? * Revert "Flake8 update ?" This reverts commit 0b917fcb520e5f34d1933d9d37d8f32b64553048. * Style (acctidentally deleted flake8 F401.) * Move to a bigger test (no small whisper model, and s2t doesn't seem to accept torch_dtype=fp16). Also we need to use a GPU to actually compute on fp16. * Using BatchFeature capability.
This commit is contained in:
@@ -145,6 +145,19 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
|
||||
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_whisper_fp16(self):
|
||||
if not torch.cuda.is_available():
|
||||
self.skipTest("Cuda is necessary for this test")
|
||||
speech_recognizer = pipeline(
|
||||
model="openai/whisper-base",
|
||||
device=0,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
speech_recognizer(waveform)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_seq2seq(self):
|
||||
speech_recognizer = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user