[Speech Examples] Add pytorch speech pretraining (#13877)
* adapt wav2vec2 * add example * add files * adapt * remove bogus file * Apply suggestions from code review * adapt files more * upload changes * del old files * up * up * up * up * up * correct gradient checkpointing * add readme * finish * finish * up * more fixes * up * up * add demo run to readme * up
This commit is contained in:
committed by
GitHub
parent
3499728dc4
commit
d45fc7da3d
@@ -586,7 +586,8 @@ class HubertUtilsTest(unittest.TestCase):
|
||||
mask_prob = 0.5
|
||||
mask_length = 1
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = torch.from_numpy(mask).to(torch_device)
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
|
||||
|
||||
@@ -596,7 +597,8 @@ class HubertUtilsTest(unittest.TestCase):
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = torch.from_numpy(mask).to(torch_device)
|
||||
|
||||
# because of overlap, the masks don't have to add up exactly to `mask_prob * sequence_length`, but they have to be smaller or equal
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
|
||||
@@ -40,7 +40,11 @@ if is_torch_available():
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
||||
Wav2Vec2GumbelVectorQuantizer,
|
||||
_compute_mask_indices,
|
||||
_sample_negative_indices,
|
||||
)
|
||||
|
||||
|
||||
class Wav2Vec2ModelTester:
|
||||
@@ -405,6 +409,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"masked_spec_embed",
|
||||
"codevectors",
|
||||
"quantizer.weight_proj.weight",
|
||||
"project_hid.weight",
|
||||
"project_hid.bias",
|
||||
"project_q.weight",
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -605,6 +615,12 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"masked_spec_embed",
|
||||
"codevectors",
|
||||
"quantizer.weight_proj.weight",
|
||||
"project_hid.weight",
|
||||
"project_hid.bias",
|
||||
"project_q.weight",
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -640,28 +656,37 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
|
||||
)
|
||||
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
)
|
||||
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices)
|
||||
|
||||
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
|
||||
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
|
||||
|
||||
loss = model(
|
||||
inputs_dict["input_values"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
mask_time_indices=mask_time_indices,
|
||||
sampled_negative_indices=sampled_negative_indices,
|
||||
).loss
|
||||
|
||||
# more losses
|
||||
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
|
||||
|
||||
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices.cpu().numpy())
|
||||
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
|
||||
loss_more_masked = model(
|
||||
inputs_dict["input_values"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
mask_time_indices=mask_time_indices,
|
||||
sampled_negative_indices=sampled_negative_indices,
|
||||
).loss
|
||||
|
||||
# loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
|
||||
@@ -727,7 +752,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
mask_prob = 0.5
|
||||
mask_length = 1
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = torch.from_numpy(mask).to(torch_device)
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
|
||||
|
||||
@@ -737,7 +763,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = torch.from_numpy(mask).to(torch_device)
|
||||
|
||||
# because of overlap, the masks don't have to add up exactly to `mask_prob * sequence_length`, but they have to be smaller or equal
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
@@ -753,8 +780,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
attention_mask[:2, sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
|
||||
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||
)
|
||||
mask = torch.from_numpy(mask).to(torch_device)
|
||||
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
@@ -785,8 +813,11 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
|
||||
|
||||
# sample negative indices
|
||||
sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
|
||||
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
|
||||
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
|
||||
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
@@ -796,15 +827,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
def test_sample_negatives_with_attn_mask(self):
|
||||
def test_sample_negatives_with_mask(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
# second half of last input tensor is padded
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||
attention_mask[-1, sequence_length // 2 :] = 0
|
||||
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||
mask[-1, sequence_length // 2 :] = 0
|
||||
|
||||
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||
sequence_length, hidden_size
|
||||
@@ -812,9 +843,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
# replace masked feature vectors with -100 to test that those are not sampled
|
||||
features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
|
||||
features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
|
||||
|
||||
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
|
||||
# sample negative indices
|
||||
sampled_negative_indices = _sample_negative_indices(
|
||||
(batch_size, sequence_length), num_negatives, mask.cpu().numpy()
|
||||
)
|
||||
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
|
||||
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
|
||||
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
|
||||
|
||||
self.assertTrue((negatives >= 0).all().item())
|
||||
|
||||
@@ -924,16 +961,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
# Wav2Vec2 pretraining seems to be broken. TODO(PVP) - reenable test once pretraining works
|
||||
# correctly
|
||||
@unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
|
||||
def test_inference_integration(self):
|
||||
return
|
||||
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
|
||||
model.to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"facebook/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
|
||||
@@ -943,19 +975,18 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(4)
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
)
|
||||
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
@@ -965,14 +996,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked = cosine_sim[mask_time_indices]
|
||||
|
||||
# cosine similarity of model is all > 0.5 as model is
|
||||
# pre-trained on contrastive loss
|
||||
# fmt: off
|
||||
expected_cosine_sim_masked = torch.tensor(
|
||||
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
|
||||
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
|
||||
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
|
||||
0.6997],
|
||||
device=torch_device,
|
||||
)
|
||||
expected_cosine_sim_masked = torch.tensor([
|
||||
0.8523, 0.5860, 0.6905, 0.5557, 0.7456, 0.5249, 0.6639, 0.7654, 0.7565,
|
||||
0.8167, 0.8222, 0.7960, 0.8034, 0.8166, 0.8310, 0.8263, 0.8274, 0.8258,
|
||||
0.8179, 0.8412, 0.8536, 0.5098, 0.4728, 0.6461, 0.4498, 0.6002, 0.5774,
|
||||
0.6457, 0.7123, 0.5668, 0.6866, 0.4960, 0.6293, 0.7423, 0.7419, 0.7526,
|
||||
0.7768, 0.4898, 0.5393, 0.8183
|
||||
], device=torch_device)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
|
||||
@@ -997,9 +1030,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
)
|
||||
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
@@ -1064,28 +1097,36 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
)
|
||||
sampled_negative_indices = _sample_negative_indices(
|
||||
mask_time_indices.shape, model.config.num_negatives, mask_time_indices
|
||||
)
|
||||
|
||||
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
|
||||
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
sampled_negative_indices=sampled_negative_indices,
|
||||
)
|
||||
|
||||
# check diversity loss
|
||||
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
|
||||
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
|
||||
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
|
||||
self.assertTrue(abs(diversity_loss.item() - 0.9538) < 1e-3)
|
||||
|
||||
# check overall loss (contrastive loss + diversity loss)
|
||||
expected_loss = 62.5170
|
||||
expected_loss = 116.7094
|
||||
|
||||
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user