[Tests] Correct Wav2Vec2 & WavLM tests (#15015)

* up

* up

* up
This commit is contained in:
Patrick von Platen
2022-01-03 20:19:04 +01:00
committed by GitHub
parent 0b4c3a1a53
commit dbac8899fe
3 changed files with 13 additions and 35 deletions

View File

@@ -290,7 +290,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
apt -y update && apt install -y libsndfile1-dev git apt -y update && apt install -y libsndfile1-dev git espeak-ng
pip install --upgrade pip pip install --upgrade pip
pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision] pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip pip install https://github.com/kpu/kenlm/archive/master.zip

View File

@@ -15,6 +15,7 @@
import copy import copy
import glob
import inspect import inspect
import math import math
import unittest import unittest
@@ -23,6 +24,7 @@ import numpy as np
import pytest import pytest
from datasets import load_dataset from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available from transformers import Wav2Vec2Config, is_tf_available
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
@@ -485,8 +487,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@slow @slow
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples): def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech # automatic decoding with librispeech
speech_samples = ds.sort("id").filter( speech_samples = ds.sort("id").filter(
@@ -556,18 +556,17 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_librosa @require_librosa
def test_wav2vec2_with_lm(self): def test_wav2vec2_with_lm(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True) downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
sample = next(iter(ds)) file_path = glob.glob(downloaded_folder + "/*")[0]
sample = librosa.load(file_path, sr=16_000)[0]
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="tf").input_values input_values = processor(sample, return_tensors="tf").input_values
logits = model(input_values).logits logits = model(input_values).logits
transcription = processor.batch_decode(logits.numpy()).text transcription = processor.batch_decode(logits.numpy()).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch WavLM model. """ """ Testing suite for the PyTorch WavLM model. """
import copy
import math import math
import unittest import unittest
@@ -452,30 +451,9 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3) module.masked_spec_embed.data.fill_(3)
# overwrite from test_modeling_common @unittest.skip(reason="Feed forward chunking is not implemented for WavLM")
# as WavLM is not very precise
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
( pass
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
@@ -528,7 +506,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
def test_inference_large(self): def test_inference_large(self):
model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device) model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"microsoft/wavlm-base-plus", return_attention_mask=True "microsoft/wavlm-large", return_attention_mask=True
) )
input_speech = self._load_datasamples(2) input_speech = self._load_datasamples(2)
@@ -544,8 +522,9 @@ class WavLMModelIntegrationTest(unittest.TestCase):
) )
EXPECTED_HIDDEN_STATES_SLICE = torch.tensor( EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]] [[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
) )
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2)) self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))
def test_inference_diarization(self): def test_inference_diarization(self):