From 8d518013efbd10c178dd0dba0f9ba93229e2e78a Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 5 Sep 2023 18:26:06 +0100 Subject: [PATCH] [Wav2Vec2 Conformer] Fix inference float16 (#25985) * [Wav2Vec2 Conformer] Fix inference float16 * fix test * fix test more * clean pipe test --- .../modeling_wav2vec2_conformer.py | 4 ++- .../test_modeling_wav2vec2_conformer.py | 31 +++++++++++++++++-- ..._pipelines_automatic_speech_recognition.py | 20 ++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 76ed22f70e..5041039a8e 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -406,13 +406,15 @@ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module): return self.cached_rotary_positional_embedding self.cached_sequence_length = sequence_length + # Embeddings are computed in the dtype of the inv_freq constant time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) embeddings = torch.cat((freqs, freqs), dim=-1) cos_embeddings = embeddings.cos()[:, None, None, :] sin_embeddings = embeddings.sin()[:, None, None, :] - self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]) + # Computed embeddings are cast to the dtype of the hidden state inputs + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states) return self.cached_rotary_positional_embedding diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py index 8bd2a2f696..fede8fb967 100644 --- a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py +++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Testing suite for the PyTorch Wav2Vec2-Conformer model. """ - import math +import tempfile import unittest import numpy as np from datasets import load_dataset from transformers import Wav2Vec2ConformerConfig, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device +from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( @@ -215,6 +215,23 @@ class Wav2Vec2ConformerModelTester: (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size), ) + def create_and_check_model_float16(self, config, input_values, attention_mask): + model = Wav2Vec2ConformerModel(config=config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = Wav2Vec2ConformerModel.from_pretrained(tmpdirname, torch_dtype=torch.float16) + + model.to(torch_device) + model.eval() + + with torch.no_grad(): + result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size) + ) + def create_and_check_batch_inference(self, config, input_values, *args): # test does not pass for models making use of `group_norm` # check: https://github.com/pytorch/fairseq/issues/3227 @@ -451,6 +468,16 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs) + @require_torch_gpu + def test_model_float16_with_relative(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative") + self.model_tester.create_and_check_model_float16(*config_and_inputs) + + @require_torch_gpu + def test_model_float16_with_rotary(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary") + self.model_tester.create_and_check_model_float16(*config_and_inputs) + def test_ctc_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 51747482ce..9ff171e867 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -901,6 +901,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): output = speech_recognizer(filename) self.assertEqual(output, {"text": "a man said to the universe sir i exist"}) + @slow + @require_torch_gpu + def test_wav2vec2_conformer_float16(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="facebook/wav2vec2-conformer-rope-large-960h-ft", + device="cuda:0", + torch_dtype=torch.float16, + framework="pt", + ) + + dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample = dataset[0]["audio"] + + output = speech_recognizer(sample) + self.assertEqual( + output, + {"text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"}, + ) + @require_torch def test_chunking_fast(self): speech_recognizer = pipeline(