[Wav2Vec2 Conformer] Fix inference float16 (#25985)
* [Wav2Vec2 Conformer] Fix inference float16 * fix test * fix test more * clean pipe test
This commit is contained in:
@@ -406,13 +406,15 @@ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
|
||||
return self.cached_rotary_positional_embedding
|
||||
|
||||
self.cached_sequence_length = sequence_length
|
||||
# Embeddings are computed in the dtype of the inv_freq constant
|
||||
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
||||
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
||||
embeddings = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
cos_embeddings = embeddings.cos()[:, None, None, :]
|
||||
sin_embeddings = embeddings.sin()[:, None, None, :]
|
||||
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
|
||||
# Computed embeddings are cast to the dtype of the hidden state inputs
|
||||
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
|
||||
return self.cached_rotary_positional_embedding
|
||||
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
|
||||
|
||||
import math
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2ConformerConfig, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
@@ -215,6 +215,23 @@ class Wav2Vec2ConformerModelTester:
|
||||
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
|
||||
)
|
||||
|
||||
def create_and_check_model_float16(self, config, input_values, attention_mask):
|
||||
model = Wav2Vec2ConformerModel(config=config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = Wav2Vec2ConformerModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask)
|
||||
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||
# test does not pass for models making use of `group_norm`
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
@@ -451,6 +468,16 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_model_float16_with_relative(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
|
||||
self.model_tester.create_and_check_model_float16(*config_and_inputs)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_model_float16_with_rotary(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
|
||||
self.model_tester.create_and_check_model_float16(*config_and_inputs)
|
||||
|
||||
def test_ctc_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
@@ -901,6 +901,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_wav2vec2_conformer_float16(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="facebook/wav2vec2-conformer-rope-large-960h-ft",
|
||||
device="cuda:0",
|
||||
torch_dtype=torch.float16,
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
output = speech_recognizer(sample)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_chunking_fast(self):
|
||||
speech_recognizer = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user