[Wav2Vec2 Conformer] Fix inference float16 (#25985)
* [Wav2Vec2 Conformer] Fix inference float16 * fix test * fix test more * clean pipe test
This commit is contained in:
@@ -406,13 +406,15 @@ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
|
|||||||
return self.cached_rotary_positional_embedding
|
return self.cached_rotary_positional_embedding
|
||||||
|
|
||||||
self.cached_sequence_length = sequence_length
|
self.cached_sequence_length = sequence_length
|
||||||
|
# Embeddings are computed in the dtype of the inv_freq constant
|
||||||
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
||||||
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
||||||
embeddings = torch.cat((freqs, freqs), dim=-1)
|
embeddings = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
|
||||||
cos_embeddings = embeddings.cos()[:, None, None, :]
|
cos_embeddings = embeddings.cos()[:, None, None, :]
|
||||||
sin_embeddings = embeddings.sin()[:, None, None, :]
|
sin_embeddings = embeddings.sin()[:, None, None, :]
|
||||||
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
|
# Computed embeddings are cast to the dtype of the hidden state inputs
|
||||||
|
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
|
||||||
return self.cached_rotary_positional_embedding
|
return self.cached_rotary_positional_embedding
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,15 +13,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
|
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import Wav2Vec2ConformerConfig, is_torch_available
|
from transformers import Wav2Vec2ConformerConfig, is_torch_available
|
||||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
|
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
@@ -215,6 +215,23 @@ class Wav2Vec2ConformerModelTester:
|
|||||||
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
|
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_and_check_model_float16(self, config, input_values, attention_mask):
|
||||||
|
model = Wav2Vec2ConformerModel(config=config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model = Wav2Vec2ConformerModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask)
|
||||||
|
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||||
# test does not pass for models making use of `group_norm`
|
# test does not pass for models making use of `group_norm`
|
||||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||||
@@ -451,6 +468,16 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_model_float16_with_relative(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
|
||||||
|
self.model_tester.create_and_check_model_float16(*config_and_inputs)
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_model_float16_with_rotary(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
|
||||||
|
self.model_tester.create_and_check_model_float16(*config_and_inputs)
|
||||||
|
|
||||||
def test_ctc_loss_inference(self):
|
def test_ctc_loss_inference(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||||
|
|||||||
@@ -901,6 +901,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
output = speech_recognizer(filename)
|
output = speech_recognizer(filename)
|
||||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_wav2vec2_conformer_float16(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="facebook/wav2vec2-conformer-rope-large-960h-ft",
|
||||||
|
device="cuda:0",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
sample = dataset[0]["audio"]
|
||||||
|
|
||||||
|
output = speech_recognizer(sample)
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
{"text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"},
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_chunking_fast(self):
|
def test_chunking_fast(self):
|
||||||
speech_recognizer = pipeline(
|
speech_recognizer = pipeline(
|
||||||
|
|||||||
Reference in New Issue
Block a user