Add Wav2Vec2 & Hubert ForSequenceClassification (#13153)

* Add hubert classifier + tests

* Add hubert classifier + tests

* Dummies for all classification tests

* Wav2Vec2 classifier + ER test

* Fix hubert integration tests

* Add hubert IC

* Pass tests for all classification tasks on Hubert

* Pass all tests + copies

* Move models to the SUPERB org
This commit is contained in:
Anton Lozhkov
2021-08-27 20:52:51 +03:00
committed by GitHub
parent 2bef3433e5
commit b6f332ecaf
16 changed files with 823 additions and 36 deletions

View File

@@ -14,7 +14,6 @@
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2 model. """
import math
import unittest
@@ -36,6 +35,7 @@ if is_torch_available():
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2Model,
Wav2Vec2Processor,
)
@@ -194,7 +194,32 @@ class Wav2Vec2ModelTester:
self.parent.assertTrue(isinstance(sum_loss, float))
self.parent.assertTrue(isinstance(mean_loss, float))
def check_training(self, config, input_values, *args):
def check_seq_classifier_loss(self, config, input_values, *args):
model = Wav2Vec2ForSequenceClassification(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
unmasked_loss = model(input_values, labels=labels).loss.item()
self.parent.assertTrue(isinstance(masked_loss, float))
self.parent.assertTrue(isinstance(unmasked_loss, float))
self.parent.assertTrue(masked_loss != unmasked_loss)
def check_ctc_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ForCTC(config=config)
model.to(torch_device)
@@ -223,6 +248,29 @@ class Wav2Vec2ModelTester:
loss.backward()
def check_seq_classifier_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ForSequenceClassification(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = Wav2Vec2ForCTC(config)
model.to(torch_device)
@@ -246,7 +294,9 @@ class Wav2Vec2ModelTester:
@require_torch
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
@@ -267,9 +317,17 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -384,7 +442,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
@@ -411,9 +471,17 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -691,6 +759,13 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
return ds["speech"][:num_samples]
def _load_superb(self, task, num_samples):
from datasets import load_dataset
ds = load_dataset("anton-l/superb_dummy", task, split="test")
return ds[:num_samples]
def test_inference_ctc_normal(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
@@ -795,7 +870,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# fmt: off
expected_cosine_sim_masked = torch.tensor(
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
0.6997],
device=torch_device,
)
# fmt: on
@@ -913,3 +991,92 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
expected_loss = 62.5170
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
def test_inference_keyword_spotting(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
input_data = self._load_superb("ks", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [7, 6, 10, 9]
# s3prl logits for the same batch
expected_logits = torch.tensor([6.1186, 11.8961, 10.2931, 6.0898], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_intent_classification(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
input_data = self._load_superb("ic", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1)
predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1)
predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
expected_labels_action = [0, 0, 2, 3]
expected_logits_action = torch.tensor([0.4568, 11.0848, 1.6621, 9.3841], device=torch_device)
expected_labels_object = [3, 10, 3, 4]
expected_logits_object = torch.tensor([1.5322, 10.7094, 5.2469, 22.1318], device=torch_device)
expected_labels_location = [0, 0, 0, 1]
expected_logits_location = torch.tensor([1.5335, 6.5096, 10.5704, 11.0569], device=torch_device)
self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
def test_inference_speaker_identification(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid")
input_data = self._load_superb("si", 4)
output_logits = []
with torch.no_grad():
for example in input_data["speech"]:
input = processor(example, return_tensors="pt", padding=True)
output = model(input.input_values.to(torch_device), attention_mask=None)
output_logits.append(output.logits[0])
output_logits = torch.stack(output_logits)
predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
expected_labels = [251, 1, 1, 3]
# s3prl logits for the same batch
expected_logits = torch.tensor([37.5627, 71.6362, 64.2419, 31.7778], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_emotion_recognition(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
input_data = self._load_superb("er", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [1, 1, 2, 2]
# s3prl logits for the same batch
expected_logits = torch.tensor([2.1722, 3.0779, 8.0287, 6.6797], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))