Fix bug in Wav2Vec2's GPU tests (#19803)

* Fix tests when running on GPU

* Fix tests that require mp.set_start_method
This commit is contained in:
Antonio Carlos Falcão Petri
2022-10-27 10:00:03 -03:00
committed by GitHub
parent f1e42bc50e
commit ea118ae2e1
3 changed files with 158 additions and 65 deletions

View File

@@ -19,6 +19,7 @@ import multiprocessing
import os
import pickle
import tempfile
import traceback
import unittest
import numpy as np
@@ -34,6 +35,7 @@ from transformers.testing_utils import (
require_soundfile,
require_torch,
require_torchaudio,
run_test_in_subprocess,
slow,
torch_device,
)
@@ -75,6 +77,7 @@ if is_torchaudio_available():
if is_pyctcdecode_available():
import pyctcdecode.decoder
from transformers import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
@@ -83,6 +86,51 @@ if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
    """Worker body for the LM-decoding "invalid pool" check.

    Passed as `target_func` to `run_test_in_subprocess`, so the forced change
    of the multiprocessing start method below stays confined to this process.

    Args:
        in_queue: queue read once (with `timeout`) before any work starts; the
            received value is discarded.
        out_queue: queue that receives a single `{"error": ...}` dict, where the
            value is `None` on success or the formatted traceback on failure.
            `join()` is called on it afterwards.
        timeout: timeout forwarded to both the queue `get` and `put` calls.
    """
    checkpoint = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
    fallback_notice = "Falling back to sequential decoding."
    expected_text = "bien y qué regalo vas a abrir primero"

    failure = None
    try:
        # Block until the parent hands over its (unused) input.
        in_queue.get(timeout=timeout)

        check = unittest.TestCase()

        dataset = load_dataset("common_voice", "es", split="test", streaming=True)
        first_sample = next(iter(dataset))

        waveform = torch.tensor(first_sample["audio"]["array"])
        resampled_audio = torchaudio.functional.resample(waveform, 48_000, 16_000).numpy()

        model = Wav2Vec2ForCTC.from_pretrained(checkpoint)
        model = model.to(torch_device)
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(checkpoint)

        input_values = processor(resampled_audio, return_tensors="pt").input_values
        with torch.no_grad():
            logits = model(input_values.to(torch_device)).logits

        # Case 1: hand batch_decode a user-managed "spawn" pool; a warning is
        # expected whenever the pool's start method is not "fork".
        with CaptureLogger(pyctcdecode.decoder.logger) as captured:
            with multiprocessing.get_context("spawn").Pool(1) as spawn_pool:
                decoded = processor.batch_decode(logits.cpu().numpy(), spawn_pool).text

        check.assertIn(fallback_notice, captured.out)
        check.assertEqual(decoded[0], expected_text)

        # Case 2: make "spawn" the process-wide default so that the pool which
        # batch_decode creates internally triggers the same warning.
        multiprocessing.set_start_method("spawn", force=True)
        with CaptureLogger(processing_wav2vec2_with_lm.logger) as captured:
            decoded = processor.batch_decode(logits.cpu().numpy()).text

        check.assertIn(fallback_notice, captured.out)
        check.assertEqual(decoded[0], expected_text)
    except Exception:
        # Report any failure (including assertion errors) to the parent as text.
        failure = traceback.format_exc()

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()
class Wav2Vec2ModelTester:
def __init__(
self,
@@ -1636,7 +1684,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# test user-managed pool
with multiprocessing.get_context("fork").Pool(2) as pool:
transcription = processor.batch_decode(logits.numpy(), pool).text
transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@@ -1644,7 +1692,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
transcription = processor.batch_decode(logits.cpu().numpy(), pool, num_processes=2).text
self.assertIn("num_process", cl.out)
self.assertIn("it will be ignored", cl.out)
@@ -1654,30 +1702,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm_invalid_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
torch_device
timeout = os.environ.get("PYTEST_TIMEOUT", 600)
run_test_in_subprocess(
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
)
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values.to(torch_device)).logits
# change default start method, which should trigger a warning if different than fork
multiprocessing.set_start_method("spawn")
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
transcription = processor.batch_decode(logits.numpy()).text
self.assertIn("Falling back to sequential decoding.", cl.out)
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
def test_inference_diarization(self):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)