[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (#23813)

* add fine-tuned with adapter layer

* Add set_target_lang to tokenizer

* Implement load adapter

* add tests

* make style

* Apply suggestions from code review

* Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py

* make fix-copies

* Apply suggestions from code review

* make fix-copies

* make style again

* mkae style again

* fix doc string

* Update tests/models/wav2vec2/test_tokenization_wav2vec2.py

* Apply suggestions from code review

* fix

* Correct wav2vec2 adapter

* mkae style

* Update src/transformers/models/wav2vec2/modeling_wav2vec2.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add more nice docs

* finish

* finish

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

* all finish

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2023-06-02 10:30:24 +01:00
committed by GitHub
parent f49a3453ca
commit 5dfd407b37
24 changed files with 823 additions and 33 deletions

View File

@@ -54,6 +54,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from safetensors.torch import save_file as safe_save_file
from transformers import (
Wav2Vec2FeatureExtractor,
@@ -67,6 +68,8 @@ if is_torch_available():
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import (
WAV2VEC2_ADAPTER_PT_FILE,
WAV2VEC2_ADAPTER_SAFE_FILE,
Wav2Vec2GumbelVectorQuantizer,
_compute_mask_indices,
_sample_negative_indices,
@@ -290,6 +293,17 @@ class Wav2Vec2ModelTester:
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
)
def create_and_check_model_with_attn_adapter(self, config, input_values, attention_mask):
config.adapter_attn_dim = 16
model = Wav2Vec2ForCTC(config=config)
self.parent.assertIsNotNone(model._adapters)
model.to(torch_device)
model.eval()
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size))
def create_and_check_batch_inference(self, config, input_values, *args):
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
@@ -844,6 +858,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
def test_model_with_attn_adapter(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_attn_adapter(*config_and_inputs)
def test_batched_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
@@ -1098,6 +1116,85 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self):
pass
def test_load_attn_adapter(self):
processor = Wav2Vec2Processor.from_pretrained(
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
)
def get_logits(model, input_features):
model = model.to(torch_device)
batch = processor(
input_features,
padding=True,
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
)
with torch.no_grad():
logits = model(
input_values=batch["input_values"].to(torch_device),
attention_mask=batch["attention_mask"].to(torch_device),
).logits
return logits
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2", adapter_attn_dim=16)
with tempfile.TemporaryDirectory() as tempdir:
model.save_pretrained(tempdir)
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
logits = get_logits(model, input_features)
adapter_weights = model._adapters
# save safe weights
safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
safe_save_file(adapter_weights, safe_filepath, metadata={"format": "pt"})
model.load_adapter("eng")
model.load_adapter("eng", use_safetensors=True)
with self.assertRaises(OSError):
model.load_adapter("eng", use_safetensors=False)
with self.assertRaises(Exception):
model.load_adapter("ita", use_safetensors=True)
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
with tempfile.TemporaryDirectory() as tempdir:
model.save_pretrained(tempdir)
model = Wav2Vec2ForCTC.from_pretrained(tempdir)
logits = get_logits(model, input_features)
adapter_weights = model._adapters
# save pt weights
pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
torch.save(adapter_weights, pt_filepath)
model.load_adapter("eng")
model.load_adapter("eng", use_safetensors=False)
with self.assertRaises(OSError):
model.load_adapter("eng", use_safetensors=True)
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
logits = get_logits(model, input_features)
model.load_adapter("eng")
model.load_adapter("eng", use_safetensors=False)
model.load_adapter("eng", use_safetensors=True)
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
@@ -1768,3 +1865,45 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# TODO: update the tolerance after the CI moves to torch 1.10
self.assertAlmostEqual(outputs.loss.item(), 17.7963, 2)
@require_torchaudio
def test_inference_mms_1b_all(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
LANG_MAP = {"it": "ita", "es": "spa", "fr": "fra", "en": "eng"}
def run_model(lang):
ds = load_dataset("common_voice", lang, split="test", streaming=True)
sample = next(iter(ds))
wav2vec2_lang = LANG_MAP[lang]
model.load_adapter(wav2vec2_lang)
processor.tokenizer.set_target_lang(wav2vec2_lang)
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
inputs = processor(resampled_audio, sampling_rate=16_000, return_tensors="pt")
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask).logits
ids = torch.argmax(outputs, dim=-1)[0]
transcription = processor.decode(ids)
return transcription
TRANSCRIPTIONS = {
"it": "mi hanno fatto un'offerta che non potevo proprio rifiutare",
"es": "bien y qué regalo vas a abrir primero",
"fr": "un vrai travail intéressant va enfin être mené sur ce sujet",
"en": "twas the time of day and olof spen slept during the summer",
}
for lang in LANG_MAP.keys():
assert run_model(lang) == TRANSCRIPTIONS[lang]