Wav2Vec2 Pretraining (#11306)
* Working quantizer forward * Working quantizer forward * Clean up unused model parts, test reproducibility * Working quantizer forward * Clean up unused model parts, test reproducibility * Remove custom outputs from the shared ones * correct conversion * correct bug * add first pretrain script * save intermediate * static shapes * save intermediate * finish first pretrain script version * more refactor * remove wanddb * refactor more * improve test * correct perplexity compute bug * finish model implementation * add to docs * finish docs * finish pretraining script * finish pretraining script * remove wandb * finish PR for merge * finish config * finish * make deepspeed work * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply suggestions * fix flaky test Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -29,8 +29,16 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
from transformers import (
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
|
||||
|
||||
|
||||
class Wav2Vec2ModelTester:
|
||||
@@ -219,13 +227,7 @@ class Wav2Vec2ModelTester:
|
||||
@require_torch
|
||||
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -316,8 +318,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = [
|
||||
"conv.weight",
|
||||
"masked_spec_embed",
|
||||
"codevectors",
|
||||
"quantizer.weight_proj.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if "conv.weight" in name or "masked_spec_embed" in name:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
@@ -333,10 +341,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
if hasattr(module, "weight_g") and module.weight_g is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "weight_v") and module.weight_v is not None:
|
||||
module.weight_v.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||
module.codevectors.data.fill_(3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
@@ -346,7 +358,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@@ -442,8 +456,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = [
|
||||
"conv.weight",
|
||||
"masked_spec_embed",
|
||||
"codevectors",
|
||||
"quantizer.weight_proj.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if "conv.weight" in name or "masked_spec_embed" in name:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
@@ -459,10 +479,47 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
if hasattr(module, "weight_g") and module.weight_g is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "weight_v") and module.weight_v is not None:
|
||||
module.weight_v.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||
module.codevectors.data.fill_(3)
|
||||
|
||||
def test_model_for_pretraining(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = Wav2Vec2ForPreTraining(config).to(torch_device)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
loss = model(
|
||||
inputs_dict["input_values"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
mask_time_indices=mask_time_indices,
|
||||
).loss
|
||||
|
||||
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
|
||||
loss_more_masked = model(
|
||||
inputs_dict["input_values"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
mask_time_indices=mask_time_indices,
|
||||
).loss
|
||||
|
||||
# loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
|
||||
self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
@@ -484,24 +541,56 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
|
||||
def test_compute_mask_indices_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 60
|
||||
sequence_length = 80
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
|
||||
# because of overlap there is a range of possible masks
|
||||
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertIn(
|
||||
int(batch_sum),
|
||||
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
|
||||
)
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
def test_compute_perplexity(self):
|
||||
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
|
||||
|
||||
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
|
||||
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
|
||||
|
||||
# mask half of the input
|
||||
mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
|
||||
mask[0] = 0
|
||||
|
||||
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
|
||||
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
|
||||
|
||||
def test_sample_negatives(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
|
||||
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
for negative in negatives:
|
||||
self.assertTrue(((negative - features) == 0).sum() == 0.0)
|
||||
|
||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@require_datasets
|
||||
@require_soundfile
|
||||
@slow
|
||||
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
@@ -586,3 +675,160 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
"his instant panic was followed by a small sharp blow high on his chest",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_integration(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model.to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked = cosine_sim[mask_time_indices]
|
||||
|
||||
# fmt: off
|
||||
expected_cosine_sim_masked = torch.tensor(
|
||||
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
|
||||
|
||||
def test_inference_pretrained(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model.to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked = cosine_sim[mask_time_indices]
|
||||
|
||||
# ... now compare to randomly initialized model
|
||||
|
||||
config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_rand = model_rand(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim_rand = torch.cosine_similarity(
|
||||
outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
|
||||
)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
|
||||
|
||||
# a pretrained wav2vec2 model has learned to predict the quantized latent states
|
||||
# => the cosine similarity between quantized states and predicted states > 0.5
|
||||
# a random wav2vec2 model has not learned to predict the quantized latent states
|
||||
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
|
||||
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
|
||||
|
||||
def test_loss_pretraining(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base",
|
||||
attention_dropout=0.0,
|
||||
feat_proj_dropout=0.0,
|
||||
hidden_dropout=0.0,
|
||||
layerdrop=0.0,
|
||||
)
|
||||
model.to(torch_device).train()
|
||||
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# check diversity loss
|
||||
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
|
||||
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
|
||||
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
|
||||
|
||||
# check overall loss (contrastive loss + diversity loss)
|
||||
expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
|
||||
|
||||
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
|
||||
|
||||
Reference in New Issue
Block a user