Fix bug in Wav2Vec2's GPU tests (#19803)
* Fix tests when running on GPU * Fix tests that require mp.set_start_method
This commit is contained in:
committed by
GitHub
parent
f1e42bc50e
commit
ea118ae2e1
@@ -15,6 +15,8 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -31,6 +33,7 @@ from transformers.testing_utils import (
|
|||||||
require_librosa,
|
require_librosa,
|
||||||
require_pyctcdecode,
|
require_pyctcdecode,
|
||||||
require_soundfile,
|
require_soundfile,
|
||||||
|
run_test_in_subprocess,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,6 +57,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
|
import pyctcdecode.decoder
|
||||||
from transformers import Wav2Vec2ProcessorWithLM
|
from transformers import Wav2Vec2ProcessorWithLM
|
||||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||||
|
|
||||||
@@ -62,6 +66,46 @@ if is_librosa_available():
|
|||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
|
|
||||||
|
def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
|
||||||
|
|
||||||
|
error = None
|
||||||
|
try:
|
||||||
|
_ = in_queue.get(timeout=timeout)
|
||||||
|
|
||||||
|
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||||
|
sample = next(iter(ds))
|
||||||
|
|
||||||
|
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||||
|
|
||||||
|
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||||
|
|
||||||
|
logits = model(input_values).logits
|
||||||
|
|
||||||
|
# use a spawn pool, which should trigger a warning if different than fork
|
||||||
|
with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool:
|
||||||
|
transcription = processor.batch_decode(np.array(logits), pool).text
|
||||||
|
|
||||||
|
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
# force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
|
||||||
|
multiprocessing.set_start_method("spawn", force=True)
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||||
|
transcription = processor.batch_decode(np.array(logits)).text
|
||||||
|
|
||||||
|
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
except Exception:
|
||||||
|
error = f"{traceback.format_exc()}"
|
||||||
|
|
||||||
|
results = {"error": error}
|
||||||
|
out_queue.put(results, timeout=timeout)
|
||||||
|
out_queue.join()
|
||||||
|
|
||||||
|
|
||||||
class FlaxWav2Vec2ModelTester:
|
class FlaxWav2Vec2ModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -575,7 +619,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# test user-managed pool
|
# test user-managed pool
|
||||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
transcription = processor.batch_decode(np.array(logits), pool).text
|
||||||
|
|
||||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
@@ -583,7 +627,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||||
2
|
2
|
||||||
) as pool:
|
) as pool:
|
||||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
transcription = processor.batch_decode(np.array(logits), pool, num_processes=2).text
|
||||||
|
|
||||||
self.assertIn("num_process", cl.out)
|
self.assertIn("num_process", cl.out)
|
||||||
self.assertIn("it will be ignored", cl.out)
|
self.assertIn("it will be ignored", cl.out)
|
||||||
@@ -593,22 +637,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
@require_librosa
|
@require_librosa
|
||||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
timeout = os.environ.get("PYTEST_TIMEOUT", 600)
|
||||||
sample = next(iter(ds))
|
run_test_in_subprocess(
|
||||||
|
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
|
||||||
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
)
|
||||||
|
|
||||||
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
|
||||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
|
||||||
|
|
||||||
input_values = processor(resampled_audio, return_tensors="np").input_values
|
|
||||||
|
|
||||||
logits = model(input_values).logits
|
|
||||||
|
|
||||||
# change default start method, which should trigger a warning if different than fork
|
|
||||||
multiprocessing.set_start_method("spawn")
|
|
||||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
|
||||||
transcription = processor.batch_decode(logits.numpy()).text
|
|
||||||
|
|
||||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
|
||||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import glob
|
|||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -27,7 +29,15 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import Wav2Vec2Config, is_tf_available
|
from transformers import Wav2Vec2Config, is_tf_available
|
||||||
from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
from transformers.testing_utils import (
|
||||||
|
CaptureLogger,
|
||||||
|
is_flaky,
|
||||||
|
require_librosa,
|
||||||
|
require_pyctcdecode,
|
||||||
|
require_tf,
|
||||||
|
run_test_in_subprocess,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -42,6 +52,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
|
import pyctcdecode.decoder
|
||||||
from transformers import Wav2Vec2ProcessorWithLM
|
from transformers import Wav2Vec2ProcessorWithLM
|
||||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||||
|
|
||||||
@@ -50,6 +61,45 @@ if is_librosa_available():
|
|||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
|
|
||||||
|
def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
|
||||||
|
|
||||||
|
error = None
|
||||||
|
try:
|
||||||
|
_ = in_queue.get(timeout=timeout)
|
||||||
|
|
||||||
|
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||||
|
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||||
|
sample = librosa.load(file_path, sr=16_000)[0]
|
||||||
|
|
||||||
|
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(sample, return_tensors="tf").input_values
|
||||||
|
|
||||||
|
logits = model(input_values).logits
|
||||||
|
|
||||||
|
# use a spawn pool, which should trigger a warning if different than fork
|
||||||
|
with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||||
|
|
||||||
|
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
unittest.TestCase().assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||||
|
|
||||||
|
# force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
|
||||||
|
multiprocessing.set_start_method("spawn", force=True)
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||||
|
transcription = processor.batch_decode(logits.numpy()).text
|
||||||
|
|
||||||
|
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
unittest.TestCase().assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||||
|
except Exception:
|
||||||
|
error = f"{traceback.format_exc()}"
|
||||||
|
|
||||||
|
results = {"error": error}
|
||||||
|
out_queue.put(results, timeout=timeout)
|
||||||
|
out_queue.join()
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFWav2Vec2ModelTester:
|
class TFWav2Vec2ModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -627,21 +677,7 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
@require_librosa
|
@require_librosa
|
||||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||||
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
timeout = os.environ.get("PYTEST_TIMEOUT", 600)
|
||||||
file_path = glob.glob(downloaded_folder + "/*")[0]
|
run_test_in_subprocess(
|
||||||
sample = librosa.load(file_path, sr=16_000)[0]
|
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
|
||||||
|
)
|
||||||
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
|
||||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
|
||||||
|
|
||||||
input_values = processor(sample, return_tensors="tf").input_values
|
|
||||||
|
|
||||||
logits = model(input_values).logits
|
|
||||||
|
|
||||||
# change default start method, which should trigger a warning if different than fork
|
|
||||||
multiprocessing.set_start_method("spawn")
|
|
||||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
|
||||||
transcription = processor.batch_decode(logits.numpy()).text
|
|
||||||
|
|
||||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
|
||||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import multiprocessing
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import traceback
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -34,6 +35,7 @@ from transformers.testing_utils import (
|
|||||||
require_soundfile,
|
require_soundfile,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torchaudio,
|
require_torchaudio,
|
||||||
|
run_test_in_subprocess,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -75,6 +77,7 @@ if is_torchaudio_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
|
import pyctcdecode.decoder
|
||||||
from transformers import Wav2Vec2ProcessorWithLM
|
from transformers import Wav2Vec2ProcessorWithLM
|
||||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||||
|
|
||||||
@@ -83,6 +86,51 @@ if is_torch_fx_available():
|
|||||||
from transformers.utils.fx import symbolic_trace
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
|
|
||||||
|
def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
|
||||||
|
|
||||||
|
error = None
|
||||||
|
try:
|
||||||
|
_ = in_queue.get(timeout=timeout)
|
||||||
|
|
||||||
|
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||||
|
sample = next(iter(ds))
|
||||||
|
|
||||||
|
resampled_audio = torchaudio.functional.resample(
|
||||||
|
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
||||||
|
).numpy()
|
||||||
|
|
||||||
|
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(input_values.to(torch_device)).logits
|
||||||
|
|
||||||
|
# use a spawn pool, which should trigger a warning if different than fork
|
||||||
|
with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
|
||||||
|
|
||||||
|
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
# force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
|
||||||
|
multiprocessing.set_start_method("spawn", force=True)
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||||
|
transcription = processor.batch_decode(logits.cpu().numpy()).text
|
||||||
|
|
||||||
|
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
except Exception:
|
||||||
|
error = f"{traceback.format_exc()}"
|
||||||
|
|
||||||
|
results = {"error": error}
|
||||||
|
out_queue.put(results, timeout=timeout)
|
||||||
|
out_queue.join()
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ModelTester:
|
class Wav2Vec2ModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1636,7 +1684,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# test user-managed pool
|
# test user-managed pool
|
||||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
|
||||||
|
|
||||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
@@ -1644,7 +1692,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||||
2
|
2
|
||||||
) as pool:
|
) as pool:
|
||||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
transcription = processor.batch_decode(logits.cpu().numpy(), pool, num_processes=2).text
|
||||||
|
|
||||||
self.assertIn("num_process", cl.out)
|
self.assertIn("num_process", cl.out)
|
||||||
self.assertIn("it will be ignored", cl.out)
|
self.assertIn("it will be ignored", cl.out)
|
||||||
@@ -1654,30 +1702,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
timeout = os.environ.get("PYTEST_TIMEOUT", 600)
|
||||||
sample = next(iter(ds))
|
run_test_in_subprocess(
|
||||||
|
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
|
||||||
resampled_audio = torchaudio.functional.resample(
|
|
||||||
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
|
||||||
).numpy()
|
|
||||||
|
|
||||||
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
|
||||||
torch_device
|
|
||||||
)
|
)
|
||||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
|
||||||
|
|
||||||
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
logits = model(input_values.to(torch_device)).logits
|
|
||||||
|
|
||||||
# change default start method, which should trigger a warning if different than fork
|
|
||||||
multiprocessing.set_start_method("spawn")
|
|
||||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
|
||||||
transcription = processor.batch_decode(logits.numpy()).text
|
|
||||||
|
|
||||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
|
||||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
|
||||||
|
|
||||||
def test_inference_diarization(self):
|
def test_inference_diarization(self):
|
||||||
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
|
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user