Add Fine-Tuning for Wav2Vec2 (#10145)

* add encode labels function to tokenizer

* start adding finetuning

* init dropout

* upload

* correct convert script

* apply changes

* fix second typo

* make first dummy training run

* adapt convert script

* push confg for comparison

* remove conf

* finish training

* adapt data collator

* add research folder

* update according to fairseq feedback

* some minor corrections

* refactor masking indices a bit

* some minor changes

* clean tokenizer

* finish clean-up

* remove previous logic

* update run script

* correct training

* finish changes

* finish model

* correct bug

* fix training a bit more

* add some tests

* finish gradient checkpointing

* finish example

* correct gradient checkpointing

* improve tokenization method

* revert changes in tokenizer

* revert general change

* adapt fine-tuning

* update

* save intermediate test

* Update README.md

* finish finetuning

* delete conversion script

* Update src/transformers/models/wav2vec2/configuration_wav2vec2.py

* Update src/transformers/models/wav2vec2/processing_wav2vec2.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* finish wav2vec2 script

* finish wav2vec2 fine-tuning

* finalize test

* correct test

* adapt tests

* finish

* remove test file

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-03-01 12:13:17 +03:00
committed by GitHub
parent 3c733f3208
commit 0234de8418
14 changed files with 932 additions and 84 deletions

View File

@@ -18,7 +18,7 @@
import math
import unittest
from tests.test_modeling_common import floats_tensor, random_attention_mask
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
@@ -30,6 +30,7 @@ if is_torch_available():
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
class Wav2Vec2ModelTester:
@@ -128,9 +129,7 @@ class Wav2Vec2ModelTester:
)
def create_and_check_batch_inference(self, config, input_values, *args):
# Not sure how to make this test pass at the moment. Batched input yields
# same results as official fairseq implementation, but gives different results
# depending on whether batched input is used or not
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
model = Wav2Vec2Model(config=config)
model.to(torch_device)
@@ -155,6 +154,62 @@ class Wav2Vec2ModelTester:
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
def check_ctc_loss(self, config, input_values, *args):
model = Wav2Vec2ForCTC(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
model.config.ctc_loss_reduction = "mean"
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
self.parent.assertTrue(abs(labels.shape[0] * labels.shape[1] * mean_loss.item() - sum_loss.item()) < 1e-3)
def check_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ForCTC(config=config)
model.to(torch_device)
model.train()
# freeze feature encoder
model.freeze_feature_extractor()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
if max_length_labels[i] < labels.shape[-1]:
# it's important that we make sure that target lenghts are at least
# one shorter than logit lenghts to prevent -inf
labels[i, max_length_labels[i] - 1 :] = -100
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def prepare_config_and_inputs_for_common(self):
config, input_values, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
@@ -165,6 +220,7 @@ class Wav2Vec2ModelTester:
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
Wav2Vec2ForCTC,
Wav2Vec2Model,
Wav2Vec2ForMaskedLM,
)
@@ -186,6 +242,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
@@ -205,6 +269,46 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
# set layer drop to 0
model.config.layerdrop = 0.0
input_values = inputs_dict["input_values"]
input_lengths = torch.tensor(
[input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
)
output_lengths = model._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
inputs_dict["labels"] = labels
outputs = model(**inputs_dict)
output = outputs[0]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
attentions = outputs.attentions[0]
hidden_states.retain_grad()
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
self.assertIsNotNone(attentions.grad)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -213,7 +317,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "conv.weight" in name:
if "conv.weight" in name or "masked_spec_embed" in name:
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
@@ -233,7 +337,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -255,6 +359,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
@@ -274,6 +386,46 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
# set layer drop to 0
model.config.layerdrop = 0.0
input_values = inputs_dict["input_values"]
input_lengths = torch.tensor(
[input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
)
output_lengths = model._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
inputs_dict["labels"] = labels
outputs = model(**inputs_dict)
output = outputs[0]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
attentions = outputs.attentions[0]
hidden_states.retain_grad()
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
self.assertIsNotNone(attentions.grad)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -282,7 +434,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "conv.weight" in name:
if "conv.weight" in name or "masked_spec_embed" in name:
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
@@ -300,6 +452,59 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
@require_torch
class Wav2Vec2UtilsTest(unittest.TestCase):
def test_compute_mask_indices(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
# because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1):
self.assertIn(
int(batch_sum),
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
)
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
# because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1):
self.assertIn(
int(batch_sum),
list(
range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
),
)
@require_torch
@slow
@require_datasets