Fix bug in Wav2Vec2's GPU tests (#19803)
* Fix tests when running on GPU * Fix tests that require mp.set_start_method
This commit is contained in:
committed by
GitHub
parent
f1e42bc50e
commit
ea118ae2e1
@@ -19,6 +19,8 @@ import glob
|
||||
import inspect
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -27,7 +29,15 @@ from datasets import load_dataset
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import Wav2Vec2Config, is_tf_available
|
||||
from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
require_librosa,
|
||||
require_pyctcdecode,
|
||||
require_tf,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -42,6 +52,7 @@ if is_tf_available():
|
||||
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
import pyctcdecode.decoder
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||
|
||||
@@ -50,6 +61,45 @@ if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
|
||||
def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
    """Subprocess target checking that LM batch decoding falls back to sequential mode.

    A sample from the `patrickvonplaten/common_voice_es_sample` snapshot is run through
    `TFWav2Vec2ForCTC` and decoded twice with `Wav2Vec2ProcessorWithLM.batch_decode`:
    once with an explicitly supplied "spawn" pool, and once after forcing "spawn" as the
    default start method. Each run must log the fallback message and yield the expected
    transcription.

    Args:
        in_queue: Queue read once before the checks start; the value is discarded.
        out_queue: Queue that receives `{"error": <traceback string or None>}`;
            `join()` is called on it after the result has been put.
        timeout: Timeout forwarded to `in_queue.get` and `out_queue.put`.
    """
    fallback_message = "Falling back to sequential decoding."
    expected_text = "el libro ha sido escrito por cervantes"
    checkpoint = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"

    try:
        _ = in_queue.get(timeout=timeout)

        # Fetch the audio sample and load it at 16 kHz.
        sample_dir = snapshot_download("patrickvonplaten/common_voice_es_sample")
        audio_path = glob.glob(sample_dir + "/*")[0]
        audio = librosa.load(audio_path, sr=16_000)[0]

        ctc_model = TFWav2Vec2ForCTC.from_pretrained(checkpoint)
        lm_processor = Wav2Vec2ProcessorWithLM.from_pretrained(checkpoint)

        features = lm_processor(audio, return_tensors="tf").input_values
        logits = ctc_model(features).logits

        # use a spawn pool, which should trigger a warning if different than fork
        with CaptureLogger(pyctcdecode.decoder.logger) as captured:
            with multiprocessing.get_context("spawn").Pool(1) as spawn_pool:
                decoded = lm_processor.batch_decode(logits.numpy(), spawn_pool).text

        unittest.TestCase().assertIn(fallback_message, captured.out)
        unittest.TestCase().assertEqual(decoded[0], expected_text)

        # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
        multiprocessing.set_start_method("spawn", force=True)
        with CaptureLogger(processing_wav2vec2_with_lm.logger) as captured:
            decoded = lm_processor.batch_decode(logits.numpy()).text

        unittest.TestCase().assertIn(fallback_message, captured.out)
        unittest.TestCase().assertEqual(decoded[0], expected_text)
    except Exception:
        # Any failure (including assertion failures) is reported to the parent as text.
        failure = traceback.format_exc()
    else:
        failure = None

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2ModelTester:
|
||||
def __init__(
|
||||
@@ -627,21 +677,7 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
    """Run the invalid-pool LM decoding check in a separate process.

    The actual assertions live in `_test_wav2vec2_with_lm_invalid_pool`, which changes
    the multiprocessing start method. Running it through `run_test_in_subprocess`
    keeps that change out of the current test process, so no decoding or
    `set_start_method` call happens here.
    """
    # Environment variables are strings, while the timeout is used as a numeric
    # queue/process timeout downstream, so convert it explicitly.
    timeout = float(os.environ.get("PYTEST_TIMEOUT", 600))
    run_test_in_subprocess(
        test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
    )
|
||||
|
||||
Reference in New Issue
Block a user