Adding support for fp16 for asr pipeline. (#20864)
* Supporting `fp16` for asr pipeline * Adding test. * Style. * Oops. * Flake8 update ? * Fixing flake8 ? * Revert "Flake8 update ?" This reverts commit 0b917fcb520e5f34d1933d9d37d8f32b64553048. * Style (accidentally deleted flake8 F401.) * Move to a bigger test (no small whisper model, and s2t doesn't seem to accept torch_dtype=fp16). Also we need to use a GPU to actually compute on fp16. * Using BatchFeature capability.
This commit is contained in:
@@ -874,6 +874,9 @@ def pipeline(
|
||||
if feature_extractor is not None:
|
||||
kwargs["feature_extractor"] = feature_extractor
|
||||
|
||||
if torch_dtype is not None:
|
||||
kwargs["torch_dtype"] = torch_dtype
|
||||
|
||||
if device is not None:
|
||||
kwargs["device"] = device
|
||||
|
||||
|
||||
@@ -52,13 +52,15 @@ def rescale_stride(stride, ratio):
|
||||
return new_strides
|
||||
|
||||
|
||||
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
|
||||
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
|
||||
inputs_len = inputs.shape[0]
|
||||
step = chunk_len - stride_left - stride_right
|
||||
for i in range(0, inputs_len, step):
|
||||
# add start and end paddings to the chunk
|
||||
chunk = inputs[i : i + chunk_len]
|
||||
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
|
||||
if dtype is not None:
|
||||
processed = processed.to(dtype=dtype)
|
||||
_stride_left = 0 if i == 0 else stride_left
|
||||
is_last = i + step + stride_left >= inputs_len
|
||||
_stride_right = 0 if is_last else stride_right
|
||||
@@ -240,6 +242,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
||||
if "ignore_warning" in kwargs:
|
||||
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
|
||||
if "torch_dtype" in kwargs:
|
||||
preprocess_params["dtype"] = kwargs["torch_dtype"]
|
||||
|
||||
postprocess_params = {}
|
||||
if "decoder_kwargs" in kwargs:
|
||||
@@ -249,7 +253,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
|
||||
if isinstance(inputs, str):
|
||||
if inputs.startswith("http://") or inputs.startswith("https://"):
|
||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||
@@ -332,12 +336,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
raise ValueError("Chunk length must be superior to stride length")
|
||||
|
||||
# make sure that
|
||||
for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right):
|
||||
for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype):
|
||||
yield item
|
||||
else:
|
||||
processed = self.feature_extractor(
|
||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
if dtype is not None:
|
||||
processed = processed.to(dtype=dtype)
|
||||
if stride is not None:
|
||||
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||
raise ValueError("Stride is only usable with CTC models, try removing it")
|
||||
@@ -366,6 +372,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
# `generate` magic to create the mask automatically won't work, we basically need to help
|
||||
# it here.
|
||||
attention_mask = model_inputs.pop("attention_mask", None)
|
||||
|
||||
tokens = self.model.generate(
|
||||
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
||||
attention_mask=attention_mask,
|
||||
|
||||
@@ -145,6 +145,19 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
|
||||
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_whisper_fp16(self):
|
||||
if not torch.cuda.is_available():
|
||||
self.skipTest("Cuda is necessary for this test")
|
||||
speech_recognizer = pipeline(
|
||||
model="openai/whisper-base",
|
||||
device=0,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
speech_recognizer(waveform)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_seq2seq(self):
|
||||
speech_recognizer = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user