Add SD and SV heads for WavLM (#14847)

* Add converted heads

* Add dummies
This commit is contained in:
Anton Lozhkov
2021-12-20 16:40:56 +03:00
committed by GitHub
parent cd583bdaa5
commit 3883e3a75e
10 changed files with 573 additions and 7 deletions

View File

@@ -863,10 +863,10 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
def test_inference_diarization(self):
model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to(
model = UniSpeechSatForAudioFrameClassification.from_pretrained("microsoft/unispeech-sat-base-plus-sd").to(
torch_device
)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sd")
input_data = self._load_superb("sd", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
@@ -892,8 +892,8 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
def test_inference_speaker_verification(self):
model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
model = UniSpeechSatForXVector.from_pretrained("microsoft/unispeech-sat-base-plus-sv").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sv")
input_data = self._load_superb("si", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)

View File

@@ -31,7 +31,14 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import Wav2Vec2FeatureExtractor, WavLMForCTC, WavLMForSequenceClassification, WavLMModel
from transformers import (
Wav2Vec2FeatureExtractor,
WavLMForAudioFrameClassification,
WavLMForCTC,
WavLMForSequenceClassification,
WavLMForXVector,
WavLMModel,
)
class WavLMModelTester:
@@ -60,6 +67,10 @@ class WavLMModelTester:
initializer_range=0.02,
vocab_size=32,
do_stable_layer_norm=False,
tdnn_dim=(32, 32),
tdnn_kernel=(3, 3),
tdnn_dilation=(1, 1),
xvector_output_dim=32,
scope=None,
):
self.parent = parent
@@ -85,6 +96,10 @@ class WavLMModelTester:
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.tdnn_dim = tdnn_dim
self.tdnn_kernel = tdnn_kernel
self.tdnn_dilation = tdnn_dilation
self.xvector_output_dim = xvector_output_dim
self.scope = scope
output_seq_length = self.seq_length
@@ -121,6 +136,10 @@ class WavLMModelTester:
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
tdnn_dim=self.tdnn_dim,
tdnn_kernel=self.tdnn_kernel,
tdnn_dilation=self.tdnn_dilation,
xvector_output_dim=self.xvector_output_dim,
)
def create_and_check_model(self, config, input_values, attention_mask):
@@ -285,7 +304,11 @@ class WavLMModelTester:
@require_torch
class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (WavLMForCTC, WavLMModel, WavLMForSequenceClassification) if is_torch_available() else ()
all_model_classes = (
(WavLMForCTC, WavLMModel, WavLMForAudioFrameClassification, WavLMForSequenceClassification, WavLMForXVector)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -398,6 +421,7 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
"feature_projection.projection.bias",
"label_embeddings_concat",
"rel_attn_embed",
"objective.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
@@ -446,6 +470,11 @@ class WavLMModelIntegrationTest(unittest.TestCase):
return [x["array"] for x in speech_samples]
def _load_superb(self, task, num_samples):
    """Return the first ``num_samples`` test examples of a ``superb_dummy`` task.

    Args:
        task: Configuration name forwarded to ``load_dataset``; the tests in
            this file pass ``"sd"`` and ``"si"``.
        num_samples: Number of leading examples to keep.

    Returns:
        The ``[:num_samples]`` slice of the "test" split of
        ``anton-l/superb_dummy``. Callers in this file index the result by
        column name (``["speech"]``).
    """
    superb_test_split = load_dataset("anton-l/superb_dummy", task, split="test")
    leading_examples = superb_test_split[:num_samples]
    return leading_examples
def test_inference_base(self):
model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
@@ -491,3 +520,54 @@ class WavLMModelIntegrationTest(unittest.TestCase):
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
)
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2))
def test_inference_diarization(self):
    """Compare the WavLM audio-frame-classification head with reference values.

    Runs ``microsoft/wavlm-base-plus-sd`` on four "sd" samples and checks
    the number of positive-logit frames per speaker for the first sample
    as well as the first four frames of logits for every sample.
    """
    checkpoint = "microsoft/wavlm-base-plus-sd"
    model = WavLMForAudioFrameClassification.from_pretrained(checkpoint).to(torch_device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

    samples = self._load_superb("sd", 4)
    encoded = feature_extractor(samples["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)

    with torch.no_grad():
        outputs = model(
            encoded.input_values.to(torch_device),
            attention_mask=encoded.attention_mask.to(torch_device),
        )

    # Binarise the logits at zero: an entry is 1 where the logit is positive.
    # The trailing axes are (num_frames, num_speakers); the leading axis is
    # presumably the batch, since the assertions below index labels[0, :, k].
    labels = (outputs.logits > 0).long()

    # s3prl logits for the same batch
    expected_logits = torch.tensor(
        [
            [[-5.9566, -8.6554], [-5.7137, -8.9386], [-5.7906, -7.0973], [-5.7829, -5.9999]],
            [[-5.2086, -7.7878], [-4.8890, -7.9312], [-4.2004, -3.9101], [-5.4480, -4.6932]],
            [[-4.6105, -6.7178], [-5.1930, -6.1635], [-2.6228, -4.1123], [-2.7646, -3.1576]],
            [[-4.4477, -7.9206], [-3.9339, -7.3707], [-4.9528, -4.8242], [-3.6921, -2.9687]],
        ],
        device=torch_device,
    )

    # Count of positive frames per speaker column for the first sample.
    first_sample_labels = labels[0]
    self.assertEqual(first_sample_labels[:, 0].sum(), 258)
    self.assertEqual(first_sample_labels[:, 1].sum(), 647)

    # Only the first four frames of each sample are compared to the reference.
    leading_logits = outputs.logits[:, :4]
    self.assertTrue(torch.allclose(leading_logits, expected_logits, atol=1e-3))
def test_inference_speaker_verification(self):
    """Compare the WavLM x-vector head with reference similarities and loss.

    Runs ``microsoft/wavlm-base-plus-sv`` on four "si" samples, L2-normalises
    the returned embeddings, and checks three pairwise cosine similarities
    plus the loss computed from the supplied labels.
    """
    model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
    input_data = self._load_superb("si", 4)

    # Pass the sampling rate explicitly, as the diarization test does, so a
    # mismatch with the extractor's configured rate fails loudly instead of
    # silently producing wrong features.
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    # One label per sample. No `.T` here: it is a no-op on a 1-D tensor and
    # its use on non-2-D tensors is deprecated in PyTorch.
    labels = torch.tensor([5, 1, 1, 3], device=torch_device)

    with torch.no_grad():
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
    embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    # id10002 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9787, 3)
    # id10006 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.5064, 3)
    # id10002 vs id10004
    self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.4780, 3)

    self.assertAlmostEqual(outputs.loss.item(), 18.4154, 3)