Add out of vocabulary error to ASR models (#12288)
* Add OOV error to ASR models * Feedback changes
This commit is contained in:
@@ -1030,6 +1030,9 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|
||||||
|
if labels.max() >= self.config.vocab_size:
|
||||||
|
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||||
|
|
||||||
# retrieve loss input_lengths from attention_mask
|
# retrieve loss input_lengths from attention_mask
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
||||||
|
|||||||
@@ -1571,6 +1571,10 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
|||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|
||||||
|
if tf.reduce_max(labels) >= self.config.vocab_size:
|
||||||
|
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||||
|
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
inputs["attention_mask"]
|
inputs["attention_mask"]
|
||||||
if inputs["attention_mask"] is not None
|
if inputs["attention_mask"] is not None
|
||||||
|
|||||||
@@ -1480,6 +1480,9 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|
||||||
|
if labels.max() >= self.config.vocab_size:
|
||||||
|
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
||||||
|
|
||||||
# retrieve loss input_lengths from attention_mask
|
# retrieve loss input_lengths from attention_mask
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||||
@@ -210,6 +212,20 @@ class HubertModelTester:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
|
model = HubertForCTC(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
input_values = input_values[:3]
|
||||||
|
|
||||||
|
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||||
|
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
|
||||||
|
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(input_values, labels=labels)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||||
@@ -242,6 +258,10 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_labels_out_of_vocab(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
# Hubert has no inputs_embeds
|
# Hubert has no inputs_embeds
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
@@ -377,6 +397,10 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_labels_out_of_vocab(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
# Hubert has no inputs_embeds
|
# Hubert has no inputs_embeds
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import math
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from transformers import Wav2Vec2Config, is_tf_available
|
from transformers import Wav2Vec2Config, is_tf_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow
|
from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow
|
||||||
@@ -202,6 +203,14 @@ class TFWav2Vec2ModelTester:
|
|||||||
|
|
||||||
self.parent.assertFalse(tf.math.is_inf(loss))
|
self.parent.assertFalse(tf.math.is_inf(loss))
|
||||||
|
|
||||||
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
|
model = TFWav2Vec2ForCTC(config)
|
||||||
|
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
|
||||||
|
max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
|
||||||
|
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size + 100)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(input_values, labels=labels)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||||
@@ -288,6 +297,10 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_labels_out_of_vocab(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
def test_train(self):
|
def test_train(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
@@ -402,6 +415,10 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_labels_out_of_vocab(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
def test_train(self):
|
def test_train(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||||
@@ -218,6 +220,20 @@ class Wav2Vec2ModelTester:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
|
model = Wav2Vec2ForCTC(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
input_values = input_values[:3]
|
||||||
|
|
||||||
|
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||||
|
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
|
||||||
|
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(input_values, labels=labels)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||||
@@ -252,6 +268,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_labels_out_of_vocab(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
# Wav2Vec2 has no inputs_embeds
|
# Wav2Vec2 has no inputs_embeds
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
@@ -392,6 +412,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_training(*config_and_inputs)
|
self.model_tester.check_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_labels_out_of_vocab(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|
||||||
# Wav2Vec2 has no inputs_embeds
|
# Wav2Vec2 has no inputs_embeds
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user