[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (#23813)
* add fine-tuned with adapter layer * Add set_target_lang to tokenizer * Implement load adapter * add tests * make style * Apply suggestions from code review * Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py * make fix-copies * Apply suggestions from code review * make fix-copies * make style again * make style again * fix doc string * Update tests/models/wav2vec2/test_tokenization_wav2vec2.py * Apply suggestions from code review * fix * Correct wav2vec2 adapter * make style * Update src/transformers/models/wav2vec2/modeling_wav2vec2.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * add more nice docs * finish * finish * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Apply suggestions from code review * all finish --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f49a3453ca
commit
5dfd407b37
@@ -54,6 +54,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
|
||||
from transformers import (
|
||||
Wav2Vec2FeatureExtractor,
|
||||
@@ -67,6 +68,8 @@ if is_torch_available():
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
||||
WAV2VEC2_ADAPTER_PT_FILE,
|
||||
WAV2VEC2_ADAPTER_SAFE_FILE,
|
||||
Wav2Vec2GumbelVectorQuantizer,
|
||||
_compute_mask_indices,
|
||||
_sample_negative_indices,
|
||||
@@ -290,6 +293,17 @@ class Wav2Vec2ModelTester:
|
||||
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
|
||||
)
|
||||
|
||||
def create_and_check_model_with_attn_adapter(self, config, input_values, attention_mask):
    """Build a CTC model with ``adapter_attn_dim`` set and check its logits shape.

    Sets ``config.adapter_attn_dim`` to 16 before constructing the model,
    asserts that the model's ``_adapters`` is not ``None``, then runs a single
    forward pass in eval mode and compares the logits shape against
    ``(batch_size, output_seq_length, vocab_size)``.
    """
    config.adapter_attn_dim = 16
    ctc_model = Wav2Vec2ForCTC(config=config)

    # The adapters must be available directly after construction.
    self.parent.assertIsNotNone(ctc_model._adapters)

    ctc_model.to(torch_device)
    ctc_model.eval()

    outputs = ctc_model(input_values, attention_mask=attention_mask)
    expected_shape = (self.batch_size, self.output_seq_length, self.vocab_size)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||
# test does not pass for models making use of `group_norm`
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
@@ -844,6 +858,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
||||
|
||||
def test_model_with_attn_adapter(self):
    """Run the model tester's attention-adapter check on freshly prepared config and inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_model_with_attn_adapter(*prepared)
|
||||
|
||||
def test_batched_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
|
||||
@@ -1098,6 +1116,85 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_load_attn_adapter(self):
    """Check that `load_adapter` restores adapter weights without changing the logits.

    Three scenarios are exercised in order:

    1. adapter weights saved as a safetensors file next to a locally saved model,
    2. adapter weights saved with `torch.save` next to a locally saved model,
    3. the `hf-internal-testing/tiny-random-wav2vec2-adapter` checkpoint, loaded
       with every `use_safetensors` setting.

    In each scenario the logits computed before and after reloading the "eng"
    adapter must agree within `atol=1e-3`.
    """
    processor = Wav2Vec2Processor.from_pretrained(
        "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
    )

    def get_logits(model, input_features):
        # Move the model to the test device, batch the raw inputs with padding,
        # and return the logits of one gradient-free forward pass.
        model = model.to(torch_device)
        batch = processor(
            input_features,
            padding=True,
            sampling_rate=processor.feature_extractor.sampling_rate,
            return_tensors="pt",
        )

        with torch.no_grad():
            logits = model(
                input_values=batch["input_values"].to(torch_device),
                attention_mask=batch["attention_mask"].to(torch_device),
            ).logits
        return logits

    # Four random inputs of 1, 3, 2 and 6 times 16_000 samples; the differing
    # lengths are why `get_logits` pads and passes an attention mask.
    input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]

    model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2", adapter_attn_dim=16)

    # Scenario 1: only a safetensors adapter file is written into the model directory.
    with tempfile.TemporaryDirectory() as tempdir:
        model.save_pretrained(tempdir)
        model = Wav2Vec2ForCTC.from_pretrained(tempdir)

        # Reference logits, taken before any adapter is (re)loaded.
        logits = get_logits(model, input_features)
        adapter_weights = model._adapters

        # save safe weights
        safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
        safe_save_file(adapter_weights, safe_filepath, metadata={"format": "pt"})

        # Both the default mode and the explicit safetensors mode must succeed.
        model.load_adapter("eng")
        model.load_adapter("eng", use_safetensors=True)

        # No `torch.save` adapter file was written here, so forcing
        # `use_safetensors=False` is expected to raise.
        with self.assertRaises(OSError):
            model.load_adapter("eng", use_safetensors=False)
        # No "ita" adapter was written at all.
        with self.assertRaises(Exception):
            model.load_adapter("ita", use_safetensors=True)
        logits_2 = get_logits(model, input_features)

        self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))

    # Scenario 2: mirror of scenario 1 with only a `torch.save` adapter file.
    with tempfile.TemporaryDirectory() as tempdir:
        model.save_pretrained(tempdir)
        model = Wav2Vec2ForCTC.from_pretrained(tempdir)

        logits = get_logits(model, input_features)
        adapter_weights = model._adapters

        # save pt weights
        pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
        torch.save(adapter_weights, pt_filepath)

        model.load_adapter("eng")
        model.load_adapter("eng", use_safetensors=False)

        # No safetensors adapter file was written into this directory.
        with self.assertRaises(OSError):
            model.load_adapter("eng", use_safetensors=True)

        logits_2 = get_logits(model, input_features)

        self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))

    # Scenario 3: a published checkpoint; all three modes are expected to work.
    # NOTE(review): presumably the repo ships the "eng" adapter in both formats - confirm.
    model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
    logits = get_logits(model, input_features)

    model.load_adapter("eng")
    model.load_adapter("eng", use_safetensors=False)
    model.load_adapter("eng", use_safetensors=True)

    logits_2 = get_logits(model, input_features)

    self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
@@ -1768,3 +1865,45 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertAlmostEqual(outputs.loss.item(), 17.7963, 2)
|
||||
|
||||
@require_torchaudio
def test_inference_mms_1b_all(self):
    """Transcribe one Common Voice sample per language with `facebook/mms-1b-all`.

    For every language in `LANG_MAP`, the matching adapter is loaded into the
    model, the tokenizer is switched to the same target language, and the
    greedy (argmax) decoding of the first streamed test sample is compared with
    a reference transcription.
    """
    model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(torch_device)
    processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")

    # Dataset config name -> language code passed to `load_adapter` / `set_target_lang`.
    LANG_MAP = {"it": "ita", "es": "spa", "fr": "fra", "en": "eng"}

    def run_model(lang):
        # Return the transcription of the first streamed test sample for `lang`.
        ds = load_dataset("common_voice", lang, split="test", streaming=True)
        sample = next(iter(ds))

        wav2vec2_lang = LANG_MAP[lang]

        # Model adapter and tokenizer vocabulary are switched together.
        model.load_adapter(wav2vec2_lang)
        processor.tokenizer.set_target_lang(wav2vec2_lang)

        # Resample from 48 kHz to the 16 kHz rate handed to the processor below.
        # NOTE(review): assumes every sample is recorded at 48 kHz - confirm.
        resampled_audio = torchaudio.functional.resample(
            torch.tensor(sample["audio"]["array"]), 48_000, 16_000
        ).numpy()

        inputs = processor(resampled_audio, sampling_rate=16_000, return_tensors="pt")
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask).logits

        # Greedy decoding of the single item in the batch.
        ids = torch.argmax(outputs, dim=-1)[0]

        transcription = processor.decode(ids)
        return transcription

    TRANSCRIPTIONS = {
        "it": "mi hanno fatto un'offerta che non potevo proprio rifiutare",
        "es": "bien y qué regalo vas a abrir primero",
        "fr": "un vrai travail intéressant va enfin être mené sur ce sujet",
        "en": "twas the time of day and olof spen slept during the summer",
    }

    for lang in LANG_MAP:
        # `assertEqual` instead of a bare `assert`: it is not stripped under
        # `python -O` and reports both strings plus the failing language.
        self.assertEqual(
            run_model(lang),
            TRANSCRIPTIONS[lang],
            msg=f"Unexpected transcription for language {lang!r}",
        )
|
||||
|
||||
@@ -772,3 +772,48 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
output = tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
self.assertIsInstance(output["text"], str)
|
||||
|
||||
def test_nested_vocab(self):
    """Check language switching on a tokenizer built from a language-keyed vocabulary.

    A `vocab.json` mapping each language code to its own token table is written
    to disk, loaded with `target_lang="eng"`, and then switched between
    languages with `set_target_lang`. After each switch both the active
    `encoder` and the decoding of a short id sequence are verified. The
    tokenizer is finally saved and reloaded to confirm that the last selected
    language ("ita") is restored.
    """
    eng_vocab = {"a": 7, "b": 8}
    spa_vocab = {"a": 23, "c": 88}
    ita_vocab = {"a": 6, "d": 9}
    vocab_by_lang = {"eng": eng_vocab, "spa": spa_vocab, "ita": ita_vocab}

    # (target language, expected encoder, ids to decode, expected text), visited in order.
    switch_steps = (
        ("eng", eng_vocab, [7, 8, 7], "aba"),
        ("spa", spa_vocab, [23, 88, 23], "aca"),
        ("eng", eng_vocab, [7, 7, 8], "ab"),
        ("ita", ita_vocab, [6, 9, 9], "ad"),
    )

    def check_tokenizer(tokenizer, check_ita_first=False):
        if check_ita_first:
            # Before any switch, the tokenizer is expected to already use the "ita" table.
            self.assertEqual(tokenizer.decode([6, 9, 9]), "ad")
            self.assertEqual(tokenizer.encoder, ita_vocab)

        for lang, expected_encoder, ids, expected_text in switch_steps:
            tokenizer.set_target_lang(lang)
            self.assertEqual(tokenizer.encoder, expected_encoder)
            self.assertEqual(tokenizer.decode(ids), expected_text)

    with tempfile.TemporaryDirectory() as tempdir:
        vocab_path = os.path.join(tempdir, "vocab.json")
        with open(vocab_path, "w") as vocab_file:
            json.dump(vocab_by_lang, vocab_file)

        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir, target_lang="eng")

    check_tokenizer(tokenizer)

    with tempfile.TemporaryDirectory() as tempdir:
        # should have saved target lang as "ita" since it was last one
        tokenizer.save_pretrained(tempdir)
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tempdir)

    self.assertEqual(tokenizer.target_lang, "ita")
    check_tokenizer(tokenizer, check_ita_first=True)
|
||||
|
||||
Reference in New Issue
Block a user