Add out of vocabulary error to ASR models (#12288)
* Add OOV error to ASR models * Feedback changes
This commit is contained in:
@@ -20,6 +20,7 @@ import math
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import Wav2Vec2Config, is_tf_available
|
||||
from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow
|
||||
@@ -202,6 +203,14 @@ class TFWav2Vec2ModelTester:
|
||||
|
||||
self.parent.assertFalse(tf.math.is_inf(loss))
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = TFWav2Vec2ForCTC(config)
|
||||
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
|
||||
max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
|
||||
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size + 100)
|
||||
with pytest.raises(ValueError):
|
||||
model(input_values, labels=labels)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||
@@ -288,6 +297,10 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
@@ -402,6 +415,10 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user