Add out of vocabulary error to ASR models (#12288)

* Add OOV error to ASR models

* Feedback changes
This commit is contained in:
Will Rice
2021-06-29 03:57:46 -04:00
committed by GitHub
parent 1fc6817a30
commit bc084938f2
6 changed files with 75 additions and 0 deletions

View File

@@ -18,6 +18,8 @@
import math
import unittest
import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
@@ -210,6 +212,20 @@ class HubertModelTester:
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = HubertForCTC(config)
model.to(torch_device)
model.train()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
with pytest.raises(ValueError):
model(input_values, labels=labels)
def prepare_config_and_inputs_for_common(self):
config, input_values, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
@@ -242,6 +258,10 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
# Hubert has no inputs_embeds
def test_inputs_embeds(self):
pass
@@ -377,6 +397,10 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
# Hubert has no inputs_embeds
def test_inputs_embeds(self):
pass