[Speech Examples] Add pytorch speech pretraining (#13877)

* adapt wav2vec2

* add example

* add files

* adapt

* remove bogus file

* Apply suggestions from code review

* adapt files more

* upload changes

* del old files

* up

* up

* up

* up

* up

* correct gradient checkpoitning

* add readme

* finish

* finish

* up

* more fixes

* up

* up

* add demo run to readme

* up
This commit is contained in:
Patrick von Platen
2021-10-12 00:46:32 +02:00
committed by GitHub
parent 3499728dc4
commit d45fc7da3d
9 changed files with 1196 additions and 183 deletions

View File

@@ -586,7 +586,8 @@ class HubertUtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
@@ -596,7 +597,8 @@ class HubertUtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):

View File

@@ -40,7 +40,11 @@ if is_torch_available():
Wav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2GumbelVectorQuantizer,
_compute_mask_indices,
_sample_negative_indices,
)
class Wav2Vec2ModelTester:
@@ -405,6 +409,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
@@ -605,6 +615,12 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
@@ -640,28 +656,37 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
loss = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
).loss
# more losses
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices.cpu().numpy())
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
loss_more_masked = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
).loss
# loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
@@ -727,7 +752,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
@@ -737,7 +763,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
@@ -753,8 +780,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
mask = torch.from_numpy(mask).to(torch_device)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
@@ -785,8 +813,11 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
) # each value in vector consits of same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
# sample negative indices
sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
@@ -796,15 +827,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_attn_mask(self):
def test_sample_negatives_with_mask(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
# second half of last input tensor is padded
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
attention_mask[-1, sequence_length // 2 :] = 0
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
mask[-1, sequence_length // 2 :] = 0
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
@@ -812,9 +843,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
# replace masked feature vectors with -100 to test that those are not sampled
features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
# sample negative indices
sampled_negative_indices = _sample_negative_indices(
(batch_size, sequence_length), num_negatives, mask.cpu().numpy()
)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue((negatives >= 0).all().item())
@@ -924,16 +961,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
# Wav2Vec2 pretraining seems to be broken. TODO(PVP) - reenable test once pretraining works
# correctly
@unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
def test_inference_integration(self):
return
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-base", return_attention_mask=True
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
@@ -943,19 +975,18 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
torch.manual_seed(0)
np.random.seed(4)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
@@ -965,14 +996,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# cosine similarity of model is all > 0.5 as model is
# pre-trained on contrastive loss
# fmt: off
expected_cosine_sim_masked = torch.tensor(
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
0.6997],
device=torch_device,
)
expected_cosine_sim_masked = torch.tensor([
0.8523, 0.5860, 0.6905, 0.5557, 0.7456, 0.5249, 0.6639, 0.7654, 0.7565,
0.8167, 0.8222, 0.7960, 0.8034, 0.8166, 0.8310, 0.8263, 0.8274, 0.8258,
0.8179, 0.8412, 0.8536, 0.5098, 0.4728, 0.6461, 0.4498, 0.6002, 0.5774,
0.6457, 0.7123, 0.5668, 0.6866, 0.4960, 0.6293, 0.7423, 0.7419, 0.7526,
0.7768, 0.4898, 0.5393, 0.8183
], device=torch_device)
# fmt: on
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
@@ -997,9 +1030,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
with torch.no_grad():
outputs = model(
@@ -1064,28 +1097,36 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
)
torch.manual_seed(0)
np.random.seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
sampled_negative_indices = _sample_negative_indices(
mask_time_indices.shape, model.config.num_negatives, mask_time_indices
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
)
# check diversity loss
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
self.assertTrue(abs(diversity_loss.item() - 0.9538) < 1e-3)
# check overall loss (contrastive loss + diversity loss)
expected_loss = 62.5170
expected_loss = 116.7094
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)