[Wav2Vec2 Conformer] Fix inference float16 (#25985)

* [Wav2Vec2 Conformer] Fix inference float16

* fix test

* fix test more

* clean pipe test
This commit is contained in:
Sanchit Gandhi
2023-09-05 18:26:06 +01:00
committed by GitHub
parent 6bc517ccd4
commit 8d518013ef
3 changed files with 52 additions and 3 deletions

View File

@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
import math
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from transformers import Wav2Vec2ConformerConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@@ -215,6 +215,23 @@ class Wav2Vec2ConformerModelTester:
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
)
def create_and_check_model_float16(self, config, input_values, attention_mask):
model = Wav2Vec2ConformerModel(config=config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = Wav2Vec2ConformerModel.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_batch_inference(self, config, input_values, *args):
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
@@ -451,6 +468,16 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
@require_torch_gpu
def test_model_float16_with_relative(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
self.model_tester.create_and_check_model_float16(*config_and_inputs)
@require_torch_gpu
def test_model_float16_with_rotary(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
self.model_tester.create_and_check_model_float16(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)