Add Wav2Vec2 & Hubert ForSequenceClassification (#13153)

* Add hubert classifier + tests

* Add hubert classifier + tests

* Dummies for all classification tests

* Wav2Vec2 classifier + ER test

* Fix hubert integration tests

* Add hubert IC

* Pass tests for all classification tasks on Hubert

* Pass all tests + copies

* Move models to the SUPERB org
This commit is contained in:
Anton Lozhkov
2021-08-27 20:52:51 +03:00
committed by GitHub
parent 2bef3433e5
commit b6f332ecaf
16 changed files with 823 additions and 36 deletions

View File

@@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import HubertForCTC, HubertModel, Wav2Vec2Processor
from transformers import (
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
@@ -187,7 +193,32 @@ class HubertModelTester:
self.parent.assertTrue(isinstance(sum_loss, float))
self.parent.assertTrue(isinstance(mean_loss, float))
def check_training(self, config, input_values, *args):
def check_seq_classifier_loss(self, config, input_values, *args):
model = HubertForSequenceClassification(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
unmasked_loss = model(input_values, labels=labels).loss.item()
self.parent.assertTrue(isinstance(masked_loss, float))
self.parent.assertTrue(isinstance(unmasked_loss, float))
self.parent.assertTrue(masked_loss != unmasked_loss)
def check_ctc_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = HubertForCTC(config=config)
model.to(torch_device)
@@ -216,6 +247,29 @@ class HubertModelTester:
loss.backward()
def check_seq_classifier_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = HubertForSequenceClassification(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = HubertForCTC(config)
model.to(torch_device)
@@ -238,7 +292,7 @@ class HubertModelTester:
@require_torch
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else ()
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -258,9 +312,17 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -371,7 +433,7 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else ()
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -397,9 +459,17 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
def test_seq_classifier_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
def test_ctc_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_training(*config_and_inputs)
def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -557,6 +627,13 @@ class HubertModelIntegrationTest(unittest.TestCase):
return ds["speech"][:num_samples]
def _load_superb(self, task, num_samples):
from datasets import load_dataset
ds = load_dataset("anton-l/superb_dummy", task, split="test")
return ds[:num_samples]
def test_inference_ctc_batched(self):
model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)
@@ -579,3 +656,95 @@ class HubertModelIntegrationTest(unittest.TestCase):
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_keyword_spotting(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
input_data = self._load_superb("ks", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [2, 6, 10, 9]
# s3prl logits for the same batch
expected_logits = torch.tensor([7.6692, 17.7795, 11.1562, 11.8232], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_intent_classification(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ic").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ic")
input_data = self._load_superb("ic", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1)
predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1)
predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
expected_labels_action = [1, 0, 4, 3]
expected_logits_action = torch.tensor([5.9052, 12.5865, 4.4840, 10.0240], device=torch_device)
expected_labels_object = [1, 10, 3, 4]
expected_logits_object = torch.tensor([5.5316, 11.7946, 8.1672, 23.2415], device=torch_device)
expected_labels_location = [0, 0, 0, 1]
expected_logits_location = torch.tensor([5.2053, 8.9577, 10.0447, 8.1481], device=torch_device)
self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=3e-1))
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=3e-1))
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=3e-1))
def test_inference_speaker_identification(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-sid")
input_data = self._load_superb("si", 4)
output_logits = []
with torch.no_grad():
for example in input_data["speech"]:
input = processor(example, return_tensors="pt", padding=True)
output = model(input.input_values.to(torch_device), attention_mask=None)
output_logits.append(output.logits[0])
output_logits = torch.stack(output_logits)
predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
expected_labels = [5, 1, 1, 3]
# s3prl logits for the same batch
expected_logits = torch.tensor([78231.5547, 123166.6094, 122785.4141, 84851.2969], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=10))
def test_inference_emotion_recognition(self):
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
input_data = self._load_superb("er", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
expected_labels = [1, 1, 2, 2]
# s3prl logits for the same batch
expected_logits = torch.tensor([2.8384, 2.3389, 3.8564, 4.5558], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-1))