Add SD and SV heads for WavLM (#14847)
* Add converted heads * Add dummies
This commit is contained in:
@@ -863,10 +863,10 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
|
||||
|
||||
def test_inference_diarization(self):
|
||||
model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to(
|
||||
model = UniSpeechSatForAudioFrameClassification.from_pretrained("microsoft/unispeech-sat-base-plus-sd").to(
|
||||
torch_device
|
||||
)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sd")
|
||||
input_data = self._load_superb("sd", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
|
||||
|
||||
@@ -892,8 +892,8 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
|
||||
|
||||
def test_inference_speaker_verification(self):
|
||||
model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
|
||||
model = UniSpeechSatForXVector.from_pretrained("microsoft/unispeech-sat-base-plus-sv").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sv")
|
||||
input_data = self._load_superb("si", 4)
|
||||
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
@@ -31,7 +31,14 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2FeatureExtractor, WavLMForCTC, WavLMForSequenceClassification, WavLMModel
|
||||
from transformers import (
|
||||
Wav2Vec2FeatureExtractor,
|
||||
WavLMForAudioFrameClassification,
|
||||
WavLMForCTC,
|
||||
WavLMForSequenceClassification,
|
||||
WavLMForXVector,
|
||||
WavLMModel,
|
||||
)
|
||||
|
||||
|
||||
class WavLMModelTester:
|
||||
@@ -60,6 +67,10 @@ class WavLMModelTester:
|
||||
initializer_range=0.02,
|
||||
vocab_size=32,
|
||||
do_stable_layer_norm=False,
|
||||
tdnn_dim=(32, 32),
|
||||
tdnn_kernel=(3, 3),
|
||||
tdnn_dilation=(1, 1),
|
||||
xvector_output_dim=32,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -85,6 +96,10 @@ class WavLMModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.tdnn_dim = tdnn_dim
|
||||
self.tdnn_kernel = tdnn_kernel
|
||||
self.tdnn_dilation = tdnn_dilation
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
self.scope = scope
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
@@ -121,6 +136,10 @@ class WavLMModelTester:
|
||||
hidden_act=self.hidden_act,
|
||||
initializer_range=self.initializer_range,
|
||||
vocab_size=self.vocab_size,
|
||||
tdnn_dim=self.tdnn_dim,
|
||||
tdnn_kernel=self.tdnn_kernel,
|
||||
tdnn_dilation=self.tdnn_dilation,
|
||||
xvector_output_dim=self.xvector_output_dim,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
@@ -285,7 +304,11 @@ class WavLMModelTester:
|
||||
|
||||
@require_torch
|
||||
class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (WavLMForCTC, WavLMModel, WavLMForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(WavLMForCTC, WavLMModel, WavLMForAudioFrameClassification, WavLMForSequenceClassification, WavLMForXVector)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@@ -398,6 +421,7 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
"rel_attn_embed",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -446,6 +470,11 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def _load_superb(self, task, num_samples):
|
||||
ds = load_dataset("anton-l/superb_dummy", task, split="test")
|
||||
|
||||
return ds[:num_samples]
|
||||
|
||||
def test_inference_base(self):
|
||||
model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
@@ -491,3 +520,54 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
||||
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2))
|
||||
|
||||
def test_inference_diarization(self):
|
||||
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sd")
|
||||
input_data = self._load_superb("sd", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
# labels is a one-hot array of shape (num_frames, num_speakers)
|
||||
labels = (outputs.logits > 0).long()
|
||||
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-5.9566, -8.6554], [-5.7137, -8.9386], [-5.7906, -7.0973], [-5.7829, -5.9999]],
|
||||
[[-5.2086, -7.7878], [-4.8890, -7.9312], [-4.2004, -3.9101], [-5.4480, -4.6932]],
|
||||
[[-4.6105, -6.7178], [-5.1930, -6.1635], [-2.6228, -4.1123], [-2.7646, -3.1576]],
|
||||
[[-4.4477, -7.9206], [-3.9339, -7.3707], [-4.9528, -4.8242], [-3.6921, -2.9687]],
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertEqual(labels[0, :, 0].sum(), 258)
|
||||
self.assertEqual(labels[0, :, 1].sum(), 647)
|
||||
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
|
||||
|
||||
def test_inference_speaker_verification(self):
|
||||
model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
|
||||
input_data = self._load_superb("si", 4)
|
||||
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
labels = torch.tensor([5, 1, 1, 3], device=torch_device).T
|
||||
|
||||
with torch.no_grad():
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
outputs = model(input_values, attention_mask=attention_mask, labels=labels)
|
||||
embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)
|
||||
|
||||
cosine_sim = torch.nn.CosineSimilarity(dim=-1)
|
||||
# id10002 vs id10002
|
||||
self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9787, 3)
|
||||
# id10006 vs id10002
|
||||
self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.5064, 3)
|
||||
# id10002 vs id10004
|
||||
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.4780, 3)
|
||||
|
||||
self.assertAlmostEqual(outputs.loss.item(), 18.4154, 3)
|
||||
|
||||
Reference in New Issue
Block a user