Add Speaker Diarization and Verification heads (#14723)
* Models
* Squashed commit of the following:
  commit 72278e1e931a16d0879acc77f65762f3364833d0
  Author: anton-l <aglozhkov@gmail.com>
  Date: Fri Dec 10 21:45:08 2021 +0300
* Add unispeech heads
* Add sd/sv automodels
* Docs cleanup
* Fix docstrings
* rename xvector classes
* examples
* Tests cleanup
* Style
* Better checkpoints for tests
* leftover docs
* apply review suggestions
* Style + init tests
* Update unispeech-sat tdnn downsampling
This commit is contained in:
@@ -33,9 +33,11 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
UniSpeechSatForAudioFrameClassification,
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatForXVector,
|
||||
UniSpeechSatModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Processor,
|
||||
@@ -70,6 +72,10 @@ class UniSpeechSatModelTester:
|
||||
mask_time_length=2,
|
||||
vocab_size=32,
|
||||
do_stable_layer_norm=False,
|
||||
tdnn_dim=(32, 32),
|
||||
tdnn_kernel=(3, 3),
|
||||
tdnn_dilation=(1, 1),
|
||||
xvector_output_dim=32,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -97,6 +103,10 @@ class UniSpeechSatModelTester:
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.tdnn_dim = tdnn_dim
|
||||
self.tdnn_kernel = tdnn_kernel
|
||||
self.tdnn_dilation = tdnn_dilation
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
self.scope = scope
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
@@ -135,6 +145,10 @@ class UniSpeechSatModelTester:
|
||||
hidden_act=self.hidden_act,
|
||||
initializer_range=self.initializer_range,
|
||||
vocab_size=self.vocab_size,
|
||||
tdnn_dim=self.tdnn_dim,
|
||||
tdnn_kernel=self.tdnn_kernel,
|
||||
tdnn_dilation=self.tdnn_dilation,
|
||||
xvector_output_dim=self.xvector_output_dim,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
@@ -277,6 +291,30 @@ class UniSpeechSatModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_xvector_training(self, config, *args):
    """Run one training-mode forward/backward pass of `UniSpeechSatForXVector`.

    Instantiates the model from `config`, calls `freeze_base_model()`, feeds
    it a zero-padded batch built with `floats_tensor` plus labels built with
    `ids_tensor`, asserts that the returned loss is not infinite and then
    backpropagates it.

    Args:
        config: model configuration used to build the x-vector model.
        *args: remaining items of the tester's config-and-inputs tuple. They
            are ignored because this check generates its own, longer inputs.
    """
    # NOTE: the CTC-only `ctc_zero_infinity` flag is intentionally not set
    # here; this check exercises the x-vector head, not a CTC loss.
    model = UniSpeechSatForXVector(config=config)
    model.to(torch_device)
    model.train()

    # freeze everything but the classification head
    model.freeze_base_model()

    # use a longer sequence length to account for TDNN temporal downsampling
    input_values = floats_tensor([self.batch_size, self.seq_length * 2], self.vocab_size)

    # the first three rows keep 1/4, 1/2 and all of their samples respectively
    input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input
    for row, length in enumerate(input_lengths):
        input_values[row, length:] = 0.0

    loss = model(input_values, labels=labels).loss
    self.parent.assertFalse(torch.isinf(loss).item())

    loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = UniSpeechSatForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -300,7 +338,14 @@ class UniSpeechSatModelTester:
|
||||
@require_torch
|
||||
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
|
||||
(
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatModel,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatForAudioFrameClassification,
|
||||
UniSpeechSatForXVector,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@@ -335,6 +380,10 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_xvector_train(self):
    """Build a fresh config/inputs tuple and run the x-vector training check on it."""
    tester = self.model_tester
    tester.check_xvector_training(*tester.prepare_config_and_inputs())
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
@@ -417,6 +466,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -623,6 +673,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -811,3 +862,56 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
|
||||
|
||||
def test_inference_diarization(self):
    """Check frame-classification outputs of the pretrained diarization checkpoint.

    Loads `anton-l/unispeech-sat-base-plus-sd`, runs it without gradients on
    four samples returned by `self._load_superb("sd", 4)` and compares
    (a) the number of positive frames per speaker for the first sample and
    (b) the first four frames of logits for every sample against reference
    values.
    """
    model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to(
        torch_device
    )
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
    input_data = self._load_superb("sd", 4)
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)

    input_values = inputs.input_values.to(torch_device)
    attention_mask = inputs.attention_mask.to(torch_device)
    with torch.no_grad():
        outputs = model(input_values, attention_mask=attention_mask)

    # Binary mask obtained by thresholding every logit independently; it is
    # indexed below as (batch, num_frames, num_speakers).
    labels = (outputs.logits > 0).long()

    # s3prl logits for the same batch
    expected_logits = torch.tensor(
        [
            [[-5.6119, -5.5845], [-3.7772, -5.4824], [-3.6914, -5.1619], [-4.7560, -5.0496]],
            [[-6.3785, -4.8365], [-5.5863, -5.4149], [-5.5639, -4.8469], [-6.1511, -4.0052]],
            [[-6.0355, -3.7414], [-5.5968, -4.8061], [-5.4620, -4.7310], [-5.5864, -4.6078]],
            [[-5.9493, -4.8963], [-4.4050, -5.4476], [-4.1755, -5.1395], [-4.0272, -4.3705]],
        ],
        device=torch_device,
    )
    # `.item()` turns the 0-d tensors into plain ints so that assertEqual
    # compares (and reports) Python numbers rather than tensors.
    self.assertEqual(labels[0, :, 0].sum().item(), 270)
    self.assertEqual(labels[0, :, 1].sum().item(), 647)
    self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
|
||||
|
||||
def test_inference_speaker_verification(self):
    """Check x-vector embeddings and loss of the pretrained verification checkpoint.

    Loads `anton-l/unispeech-sat-base-plus-sv`, runs it without gradients on
    four samples returned by `self._load_superb("si", 4)` and compares the
    cosine similarity of selected embedding pairs, as well as the returned
    loss, against reference values (3 decimal places).
    """
    model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
    input_data = self._load_superb("si", 4)

    # Pass the sampling rate explicitly, as the sibling inference tests do.
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    # One label per sample; a 1-D tensor needs no transpose.
    labels = torch.tensor([5, 1, 1, 3], device=torch_device)

    with torch.no_grad():
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
    embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    # id10002 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9671, 3)
    # id10006 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.4941, 3)
    # id10002 vs id10004
    self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.5616, 3)

    self.assertAlmostEqual(outputs.loss.item(), 18.5925, 3)
|
||||
|
||||
@@ -44,10 +44,12 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForAudioFrameClassification,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2ForXVector,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
@@ -96,6 +98,10 @@ class Wav2Vec2ModelTester:
|
||||
do_stable_layer_norm=False,
|
||||
num_adapter_layers=1,
|
||||
adapter_stride=2,
|
||||
tdnn_dim=(32, 32),
|
||||
tdnn_kernel=(5, 3),
|
||||
tdnn_dilation=(1, 2),
|
||||
xvector_output_dim=32,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -126,6 +132,10 @@ class Wav2Vec2ModelTester:
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.scope = scope
|
||||
self.tdnn_dim = tdnn_dim
|
||||
self.tdnn_kernel = tdnn_kernel
|
||||
self.tdnn_dilation = tdnn_dilation
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
|
||||
@@ -168,6 +178,10 @@ class Wav2Vec2ModelTester:
|
||||
vocab_size=self.vocab_size,
|
||||
num_adapter_layers=self.num_adapter_layers,
|
||||
adapter_stride=self.adapter_stride,
|
||||
tdnn_dim=self.tdnn_dim,
|
||||
tdnn_kernel=self.tdnn_kernel,
|
||||
tdnn_dilation=self.tdnn_dilation,
|
||||
xvector_output_dim=self.xvector_output_dim,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
@@ -332,6 +346,29 @@ class Wav2Vec2ModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_xvector_training(self, config, input_values, *args):
    """Run one training-mode forward/backward pass of `Wav2Vec2ForXVector`.

    Instantiates the model from `config`, calls `freeze_base_model()`, feeds
    it the first three rows of `input_values` (zero-padded to different
    lengths) together with labels built with `ids_tensor`, asserts that the
    returned loss is not infinite and then backpropagates it.

    Args:
        config: model configuration used to build the x-vector model.
        input_values: batch of input values; only the first three rows are
            used, and they are zero-padded in place.
        *args: remaining items of the tester's config-and-inputs tuple; unused.
    """
    # NOTE: the CTC-only `ctc_zero_infinity` flag is intentionally not set
    # here; this check exercises the x-vector head, not a CTC loss.
    model = Wav2Vec2ForXVector(config=config)
    model.to(torch_device)
    model.train()

    # freeze everything but the classification head
    model.freeze_base_model()

    input_values = input_values[:3]

    # the three rows keep 1/4, 1/2 and all of their samples respectively
    input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input
    for row, length in enumerate(input_lengths):
        input_values[row, length:] = 0.0

    loss = model(input_values, labels=labels).loss
    self.parent.assertFalse(torch.isinf(loss).item())

    loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = Wav2Vec2ForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -398,6 +435,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_xvector_train(self):
    """Build a fresh config/inputs tuple and run the x-vector training check on it."""
    tester = self.model_tester
    tester.check_xvector_training(*tester.prepare_config_and_inputs())
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
@@ -489,6 +530,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -573,7 +615,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
|
||||
(
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForAudioFrameClassification,
|
||||
Wav2Vec2ForXVector,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@@ -622,6 +672,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_xvector_train(self):
    """Build a fresh config/inputs tuple and run the x-vector training check on it."""
    tester = self.model_tester
    tester.check_xvector_training(*tester.prepare_config_and_inputs())
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
@@ -703,6 +757,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -1369,3 +1424,54 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
transcription = processor.batch_decode(logits.cpu().numpy()).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
def test_inference_diarization(self):
    """Check frame-classification outputs of the pretrained diarization checkpoint.

    Loads `anton-l/wav2vec2-base-superb-sd`, runs it without gradients on four
    samples returned by `self._load_superb("sd", 4)` and compares (a) the
    number of positive frames per speaker for the first sample and (b) the
    first four frames of logits for every sample against reference values.
    """
    model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
    input_data = self._load_superb("sd", 4)
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)

    input_values = inputs.input_values.to(torch_device)
    attention_mask = inputs.attention_mask.to(torch_device)
    with torch.no_grad():
        outputs = model(input_values, attention_mask=attention_mask)

    # Binary mask obtained by thresholding every logit independently; it is
    # indexed below as (batch, num_frames, num_speakers).
    labels = (outputs.logits > 0).long()

    # s3prl logits for the same batch
    expected_logits = torch.tensor(
        [
            [[-5.2807, -5.1272], [-5.4059, -4.7757], [-5.2764, -4.9621], [-5.0117, -4.5851]],
            [[-1.7643, -0.5462], [-1.7369, -0.2649], [-1.5066, -0.6200], [-4.5703, -2.4863]],
            [[-0.8656, -0.4783], [-0.8899, -0.3289], [-0.9267, -0.5781], [-0.7817, -0.4619]],
            [[-4.8625, -2.5316], [-5.2339, -2.2155], [-4.9835, -2.0344], [-4.4727, -1.8421]],
        ],
        device=torch_device,
    )
    # `.item()` turns the 0-d tensors into plain ints so that assertEqual
    # compares (and reports) Python numbers rather than tensors.
    self.assertEqual(labels[0, :, 0].sum().item(), 555)
    self.assertEqual(labels[0, :, 1].sum().item(), 299)
    self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
|
||||
|
||||
def test_inference_speaker_verification(self):
    """Check x-vector embeddings and loss of the pretrained verification checkpoint.

    Loads `anton-l/wav2vec2-base-superb-sv`, runs it without gradients on four
    samples returned by `self._load_superb("si", 4)` and compares the cosine
    similarity of selected embedding pairs, as well as the returned loss,
    against reference values (3 decimal places).
    """
    model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sv")
    input_data = self._load_superb("si", 4)

    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    # One label per sample; a 1-D tensor needs no transpose.
    labels = torch.tensor([5, 1, 1, 3], device=torch_device)

    with torch.no_grad():
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
    embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1).cpu()

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    # Compare plain Python floats (`.item()`) rather than 0-d arrays.
    # id10002 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9758, 3)
    # id10006 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.7579, 3)
    # id10002 vs id10004
    self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.7594, 3)

    self.assertAlmostEqual(outputs.loss.item(), 17.7963, 3)
|
||||
|
||||
Reference in New Issue
Block a user