Add Wav2Vec2 & Hubert ForSequenceClassification (#13153)
* Add hubert classifier + tests * Add hubert classifier + tests * Dummies for all classification tests * Wav2Vec2 classifier + ER test * Fix hubert integration tests * Add hubert IC * Pass tests for all classification tasks on Hubert * Pass all tests + copies * Move models to the SUPERB org
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
||||
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
@@ -36,6 +35,7 @@ if is_torch_available():
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
@@ -194,7 +194,32 @@ class Wav2Vec2ModelTester:
|
||||
self.parent.assertTrue(isinstance(sum_loss, float))
|
||||
self.parent.assertTrue(isinstance(mean_loss, float))
|
||||
|
||||
def check_training(self, config, input_values, *args):
|
||||
def check_seq_classifier_loss(self, config, input_values, *args):
|
||||
model = Wav2Vec2ForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
|
||||
# make sure that dropout is disabled
|
||||
model.eval()
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0
|
||||
|
||||
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||
unmasked_loss = model(input_values, labels=labels).loss.item()
|
||||
|
||||
self.parent.assertTrue(isinstance(masked_loss, float))
|
||||
self.parent.assertTrue(isinstance(unmasked_loss, float))
|
||||
self.parent.assertTrue(masked_loss != unmasked_loss)
|
||||
|
||||
def check_ctc_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = Wav2Vec2ForCTC(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -223,6 +248,29 @@ class Wav2Vec2ModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_seq_classifier_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = Wav2Vec2ForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# freeze everything but the classification head
|
||||
model.freeze_base_model()
|
||||
|
||||
input_values = input_values[:3]
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
|
||||
loss = model(input_values, labels=labels).loss
|
||||
self.parent.assertFalse(torch.isinf(loss).item())
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = Wav2Vec2ForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -246,7 +294,9 @@ class Wav2Vec2ModelTester:
|
||||
@require_torch
|
||||
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -267,9 +317,17 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
def test_seq_classifier_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||
|
||||
def test_ctc_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_training(*config_and_inputs)
|
||||
|
||||
def test_seq_classifier_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -384,7 +442,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -411,9 +471,17 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
def test_seq_classifier_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||
|
||||
def test_ctc_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_training(*config_and_inputs)
|
||||
|
||||
def test_seq_classifier_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -691,6 +759,13 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
return ds["speech"][:num_samples]
|
||||
|
||||
def _load_superb(self, task, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("anton-l/superb_dummy", task, split="test")
|
||||
|
||||
return ds[:num_samples]
|
||||
|
||||
def test_inference_ctc_normal(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
@@ -795,7 +870,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
expected_cosine_sim_masked = torch.tensor(
|
||||
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
|
||||
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
|
||||
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
|
||||
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
|
||||
0.6997],
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
@@ -913,3 +991,92 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_loss = 62.5170
|
||||
|
||||
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
|
||||
|
||||
def test_inference_keyword_spotting(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
|
||||
input_data = self._load_superb("ks", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
|
||||
|
||||
expected_labels = [7, 6, 10, 9]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([6.1186, 11.8961, 10.2931, 6.0898], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_intent_classification(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
|
||||
input_data = self._load_superb("ic", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
|
||||
predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1)
|
||||
predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1)
|
||||
predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
|
||||
|
||||
expected_labels_action = [0, 0, 2, 3]
|
||||
expected_logits_action = torch.tensor([0.4568, 11.0848, 1.6621, 9.3841], device=torch_device)
|
||||
expected_labels_object = [3, 10, 3, 4]
|
||||
expected_logits_object = torch.tensor([1.5322, 10.7094, 5.2469, 22.1318], device=torch_device)
|
||||
expected_labels_location = [0, 0, 0, 1]
|
||||
expected_logits_location = torch.tensor([1.5335, 6.5096, 10.5704, 11.0569], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
|
||||
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
|
||||
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
|
||||
|
||||
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
|
||||
|
||||
def test_inference_speaker_identification(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid")
|
||||
input_data = self._load_superb("si", 4)
|
||||
|
||||
output_logits = []
|
||||
with torch.no_grad():
|
||||
for example in input_data["speech"]:
|
||||
input = processor(example, return_tensors="pt", padding=True)
|
||||
output = model(input.input_values.to(torch_device), attention_mask=None)
|
||||
output_logits.append(output.logits[0])
|
||||
output_logits = torch.stack(output_logits)
|
||||
predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
|
||||
|
||||
expected_labels = [251, 1, 1, 3]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([37.5627, 71.6362, 64.2419, 31.7778], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_emotion_recognition(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
|
||||
input_data = self._load_superb("er", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
|
||||
|
||||
expected_labels = [1, 1, 2, 2]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([2.1722, 3.0779, 8.0287, 6.6797], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
Reference in New Issue
Block a user