From ea118ae2e1ef62e909626f1b5a4487f5d1cb4a55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Carlos=20Falc=C3=A3o=20Petri?= Date: Thu, 27 Oct 2022 10:00:03 -0300 Subject: [PATCH] Fix bug in Wav2Vec2's GPU tests (#19803) * Fix tests when running on GPU * Fix tests that require mp.set_start_method --- .../wav2vec2/test_modeling_flax_wav2vec2.py | 71 ++++++++++++----- .../wav2vec2/test_modeling_tf_wav2vec2.py | 74 +++++++++++++----- .../models/wav2vec2/test_modeling_wav2vec2.py | 78 +++++++++++++------ 3 files changed, 158 insertions(+), 65 deletions(-) diff --git a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py index 29642deae0..ac1dd3bcb4 100644 --- a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py @@ -15,6 +15,8 @@ import inspect import math import multiprocessing +import os +import traceback import unittest import numpy as np @@ -31,6 +33,7 @@ from transformers.testing_utils import ( require_librosa, require_pyctcdecode, require_soundfile, + run_test_in_subprocess, slow, ) @@ -54,6 +57,7 @@ if is_flax_available(): if is_pyctcdecode_available(): + import pyctcdecode.decoder from transformers import Wav2Vec2ProcessorWithLM from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm @@ -62,6 +66,46 @@ if is_librosa_available(): import librosa +def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout): + + error = None + try: + _ = in_queue.get(timeout=timeout) + + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000) + + model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="np").input_values + + logits = model(input_values).logits + + # use a spawn pool, which should trigger a warning if different than fork + with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool: + transcription = processor.batch_decode(np.array(logits), pool).text + + unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) + unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork + multiprocessing.set_start_method("spawn", force=True) + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: + transcription = processor.batch_decode(np.array(logits)).text + + unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) + unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + class FlaxWav2Vec2ModelTester: def __init__( self, @@ -575,7 +619,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): # test user-managed pool with multiprocessing.get_context("fork").Pool(2) as pool: - transcription = processor.batch_decode(logits.numpy(), pool).text + transcription = processor.batch_decode(np.array(logits), pool).text self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") @@ -583,7 +627,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool( 2 ) as pool: - transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text + transcription = processor.batch_decode(np.array(logits), pool, num_processes=2).text self.assertIn("num_process", cl.out) self.assertIn("it will be ignored", cl.out) @@ -593,22 +637,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): @require_pyctcdecode @require_librosa def test_wav2vec2_with_lm_invalid_pool(self): - ds = load_dataset("common_voice", "es", split="test", streaming=True) - sample = next(iter(ds)) - - resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000) - - model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") - processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") - - input_values = processor(resampled_audio, return_tensors="np").input_values - - logits = model(input_values).logits - - # change default start method, which should trigger a warning if different than fork - multiprocessing.set_start_method("spawn") - with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: - transcription = processor.batch_decode(logits.numpy()).text - - self.assertIn("Falling back to sequential decoding.", cl.out) - self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + timeout = os.environ.get("PYTEST_TIMEOUT", 600) + run_test_in_subprocess( + test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout + ) diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 5c2a25f413..8f9a8f0bd7 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -19,6 +19,8 @@ import glob import inspect import math import multiprocessing +import os +import traceback import unittest import numpy as np @@ -27,7 +29,15 @@ from datasets import load_dataset from huggingface_hub import snapshot_download from transformers import Wav2Vec2Config, is_tf_available -from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow +from transformers.testing_utils import ( + CaptureLogger, + is_flaky, + require_librosa, + require_pyctcdecode, + require_tf, + run_test_in_subprocess, + slow, +) from transformers.utils import is_librosa_available, is_pyctcdecode_available from ...test_configuration_common import ConfigTester @@ -42,6 +52,7 @@ if is_tf_available(): if is_pyctcdecode_available(): + import pyctcdecode.decoder from transformers import Wav2Vec2ProcessorWithLM from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm @@ -50,6 +61,45 @@ if is_librosa_available(): import librosa +def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout): + + error = None + try: + _ = in_queue.get(timeout=timeout) + + downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample") + file_path = glob.glob(downloaded_folder + "/*")[0] + sample = librosa.load(file_path, sr=16_000)[0] + + model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(sample, return_tensors="tf").input_values + + logits = model(input_values).logits + + # use a spawn pool, which should trigger a warning if different than fork + with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool: + transcription = processor.batch_decode(logits.numpy(), pool).text + + unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) + unittest.TestCase().assertEqual(transcription[0], "el libro ha sido escrito por cervantes") + + # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork + multiprocessing.set_start_method("spawn", force=True) + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: + transcription = processor.batch_decode(logits.numpy()).text + + unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) + unittest.TestCase().assertEqual(transcription[0], "el libro ha sido escrito por cervantes") + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + @require_tf class TFWav2Vec2ModelTester: def __init__( @@ -627,21 +677,7 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): @require_pyctcdecode @require_librosa def test_wav2vec2_with_lm_invalid_pool(self): - downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample") - file_path = glob.glob(downloaded_folder + "/*")[0] - sample = librosa.load(file_path, sr=16_000)[0] - - model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") - processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") - - input_values = processor(sample, return_tensors="tf").input_values - - logits = model(input_values).logits - - # change default start method, which should trigger a warning if different than fork - multiprocessing.set_start_method("spawn") - with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: - transcription = processor.batch_decode(logits.numpy()).text - - self.assertIn("Falling back to sequential decoding.", cl.out) - self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes") + timeout = os.environ.get("PYTEST_TIMEOUT", 600) + run_test_in_subprocess( + test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout + ) diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 4d4fa2981d..aa17f4ba63 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -19,6 +19,7 @@ import multiprocessing import os import pickle import tempfile +import traceback import unittest import numpy as np @@ -34,6 +35,7 @@ from transformers.testing_utils import ( require_soundfile, require_torch, require_torchaudio, + run_test_in_subprocess, slow, torch_device, ) @@ -75,6 +77,7 @@ if is_torchaudio_available(): if is_pyctcdecode_available(): + import pyctcdecode.decoder from transformers import Wav2Vec2ProcessorWithLM from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm @@ -83,6 +86,51 @@ if is_torch_fx_available(): from transformers.utils.fx import symbolic_trace +def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout): + + error = None + try: + _ = in_queue.get(timeout=timeout) + + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = torchaudio.functional.resample( + torch.tensor(sample["audio"]["array"]), 48_000, 16_000 + ).numpy() + + model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to( + torch_device + ) + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="pt").input_values + + with torch.no_grad(): + logits = model(input_values.to(torch_device)).logits + + # use a spawn pool, which should trigger a warning if different than fork + with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool: + transcription = processor.batch_decode(logits.cpu().numpy(), pool).text + + unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) + unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork + multiprocessing.set_start_method("spawn", force=True) + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: + transcription = processor.batch_decode(logits.cpu().numpy()).text + + unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) + unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + class Wav2Vec2ModelTester: def __init__( self, @@ -1636,7 +1684,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): # test user-managed pool with multiprocessing.get_context("fork").Pool(2) as pool: - transcription = processor.batch_decode(logits.numpy(), pool).text + transcription = processor.batch_decode(logits.cpu().numpy(), pool).text self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") @@ -1644,7 +1692,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool( 2 ) as pool: - transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text + transcription = processor.batch_decode(logits.cpu().numpy(), pool, num_processes=2).text self.assertIn("num_process", cl.out) self.assertIn("it will be ignored", cl.out) @@ -1654,30 +1702,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): @require_pyctcdecode @require_torchaudio def test_wav2vec2_with_lm_invalid_pool(self): - ds = load_dataset("common_voice", "es", split="test", streaming=True) - sample = next(iter(ds)) - - resampled_audio = torchaudio.functional.resample( - torch.tensor(sample["audio"]["array"]), 48_000, 16_000 - ).numpy() - - model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to( - torch_device + timeout = os.environ.get("PYTEST_TIMEOUT", 600) + run_test_in_subprocess( + test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout ) - processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") - - input_values = processor(resampled_audio, return_tensors="pt").input_values - - with torch.no_grad(): - logits = model(input_values.to(torch_device)).logits - - # change default start method, which should trigger a warning if different than fork - multiprocessing.set_start_method("spawn") - with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: - transcription = processor.batch_decode(logits.numpy()).text - - self.assertIn("Falling back to sequential decoding.", cl.out) - self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes") def test_inference_diarization(self): model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)