From bc084938f2f268a5f818e3db519932304acbb309 Mon Sep 17 00:00:00 2001 From: Will Rice Date: Tue, 29 Jun 2021 03:57:46 -0400 Subject: [PATCH] Add out of vocabulary error to ASR models (#12288) * Add OOV error to ASR models * Feedback changes --- .../models/hubert/modeling_hubert.py | 3 +++ .../models/wav2vec2/modeling_tf_wav2vec2.py | 4 ++++ .../models/wav2vec2/modeling_wav2vec2.py | 3 +++ tests/test_modeling_hubert.py | 24 +++++++++++++++++++ tests/test_modeling_tf_wav2vec2.py | 17 +++++++++++++ tests/test_modeling_wav2vec2.py | 24 +++++++++++++++++++ 6 files changed, 75 insertions(+) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index cad377eb66..8154f2fe20 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -1030,6 +1030,9 @@ class HubertForCTC(HubertPreTrainedModel): loss = None if labels is not None: + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + # retrieve loss input_lengths from attention_mask attention_mask = ( attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index a7d82f2b32..1517ec2c6c 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -1571,6 +1571,10 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): logits = self.lm_head(hidden_states) if labels is not None: + + if tf.reduce_max(labels) >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + attention_mask = ( inputs["attention_mask"] if inputs["attention_mask"] is not None diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 87b78c6aee..2f1b4ed991 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1480,6 +1480,9 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): loss = None if labels is not None: + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + # retrieve loss input_lengths from attention_mask attention_mask = ( attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) diff --git a/tests/test_modeling_hubert.py b/tests/test_modeling_hubert.py index 90fc004393..016c03cefc 100644 --- a/tests/test_modeling_hubert.py +++ b/tests/test_modeling_hubert.py @@ -18,6 +18,8 @@ import math import unittest +import pytest + from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from transformers import is_torch_available from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device @@ -210,6 +212,20 @@ class HubertModelTester: loss.backward() + def check_labels_out_of_vocab(self, config, input_values, *args): + model = HubertForCTC(config) + model.to(torch_device) + model.train() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100) + + with pytest.raises(ValueError): + model(input_values, labels=labels) + def prepare_config_and_inputs_for_common(self): config, input_values, attention_mask = self.prepare_config_and_inputs() inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} @@ -242,6 +258,10 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + # Hubert has no inputs_embeds def test_inputs_embeds(self): pass @@ -377,6 +397,10 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + # Hubert has no inputs_embeds def test_inputs_embeds(self): pass diff --git a/tests/test_modeling_tf_wav2vec2.py b/tests/test_modeling_tf_wav2vec2.py index 47c378cc88..889790c75e 100644 --- a/tests/test_modeling_tf_wav2vec2.py +++ b/tests/test_modeling_tf_wav2vec2.py @@ -20,6 +20,7 @@ import math import unittest import numpy as np +import pytest from transformers import Wav2Vec2Config, is_tf_available from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow @@ -202,6 +203,14 @@ class TFWav2Vec2ModelTester: self.parent.assertFalse(tf.math.is_inf(loss)) + def check_labels_out_of_vocab(self, config, input_values, *args): + model = TFWav2Vec2ForCTC(config) + input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]]) + max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths) + labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size + 100) + with pytest.raises(ValueError): + model(input_values, labels=labels) + def prepare_config_and_inputs_for_common(self): config, input_values, attention_mask = self.prepare_config_and_inputs() inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} @@ -288,6 +297,10 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + def test_train(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) @@ -402,6 +415,10 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + def test_train(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index f9fa91a476..214349ea86 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -18,6 +18,8 @@ import math import unittest +import pytest + from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from transformers import is_torch_available from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device @@ -218,6 +220,20 @@ class Wav2Vec2ModelTester: loss.backward() + def check_labels_out_of_vocab(self, config, input_values, *args): + model = Wav2Vec2ForCTC(config) + model.to(torch_device) + model.train() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100) + + with pytest.raises(ValueError): + model(input_values, labels=labels) + def prepare_config_and_inputs_for_common(self): config, input_values, attention_mask = self.prepare_config_and_inputs() inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} @@ -252,6 +268,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + # Wav2Vec2 has no inputs_embeds def test_inputs_embeds(self): pass @@ -392,6 +412,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_training(*config_and_inputs) + def test_labels_out_of_vocab(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_labels_out_of_vocab(*config_and_inputs) + # Wav2Vec2 has no inputs_embeds def test_inputs_embeds(self): pass