From 343684210289706c376c85819eae8d860b6a40ef Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 21 Oct 2022 21:59:18 +0200 Subject: [PATCH] Run some TF Whisper tests in subprocesses to avoid GPU OOM (#19772) * Run some TF Whisper tests in subprocesses to avoid GPU OOM Co-authored-by: ydshieh --- src/transformers/testing_utils.py | 41 +++ .../whisper/test_modeling_tf_whisper.py | 308 +++++++++++------- 2 files changed, 234 insertions(+), 115 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index eacaf61267..7bbecd3320 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -17,6 +17,7 @@ import contextlib import functools import inspect import logging +import multiprocessing import os import re import shlex @@ -1672,3 +1673,43 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None): return wrapper return decorator + + +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to 600): + The timeout (in seconds) that will be passed to the input and output queues. + """ + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. + try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f'{results["error"]}') diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index a5503a0b2a..ae99c94087 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -15,13 +15,15 @@ """ Testing suite for the TensorFlow Whisper model. """ import inspect +import os import tempfile +import traceback import unittest import numpy as np from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor -from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, slow +from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow from transformers.utils import cached_property from transformers.utils.import_utils import is_datasets_available @@ -626,6 +628,184 @@ class TFWhisperModelTest(TFModelTesterMixin, unittest.TestCase): self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids)) +def _load_datasamples(num_samples): + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # automatic decoding with librispeech + speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] + + return [x["array"] for x in speech_samples] + + +def _test_large_logits_librispeech(in_queue, out_queue, timeout): + + error = None + try: + _ = in_queue.get(timeout=timeout) + + set_seed(0) + + model = TFWhisperModel.from_pretrained("openai/whisper-large") + + input_speech = _load_datasamples(1) + + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf") + input_features = processed_inputs.input_features + labels = processed_inputs.labels + + logits = model( + input_features, + decoder_input_ids=labels, + output_hidden_states=False, + output_attentions=False, + use_cache=False, + ) + + logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0]) + + # fmt: off + EXPECTED_LOGITS = tf.convert_to_tensor( + [ + 2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472, + 1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357, + 1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376, + 1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184 + ] + ) + # fmt: on + + unittest.TestCase().assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +def _test_large_generation(in_queue, out_queue, timeout): + + error = None + try: + _ = in_queue.get(timeout=timeout) + + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large") + + input_speech = _load_datasamples(1) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + generated_ids = model.generate(input_features, do_sample=False, max_length=20) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad" + unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +def _test_large_generation_multilingual(in_queue, out_queue, timeout): + + error = None + try: + _ = in_queue.get(timeout=timeout) + + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large") + + ds = load_dataset("common_voice", "ja", split="test", streaming=True) + ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + input_speech = next(iter(ds))["audio"]["array"] + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") + generated_ids = model.generate(input_features, do_sample=False, max_length=20) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" + unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + generated_ids = model.generate( + input_features, + do_sample=False, + max_length=20, + ) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita" + unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") + generated_ids = model.generate(input_features, do_sample=False, max_length=20) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" + unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +def _test_large_batched_generation(in_queue, out_queue, timeout): + + error = None + try: + _ = in_queue.get(timeout=timeout) + + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large") + + input_speech = _load_datasamples(4) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features + generated_ids_1 = model.generate(input_features[0:2], max_length=20) + generated_ids_2 = model.generate(input_features[2:4], max_length=20) + generated_ids = np.concatenate([generated_ids_1, generated_ids_2]) + + # fmt: off + EXPECTED_LOGITS = tf.convert_to_tensor( + [ + [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281], + [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257], + [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256], + [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11] + ] + ) + # fmt: on + + unittest.TestCase().assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS)) + + # fmt: off + EXPECTED_TRANSCRIPT = [ + ' Mr. Quilter is the apostle of the middle classes and we are glad to', + " Nor is Mr. Quilter's manner less interesting than his matter.", + " He tells us that at this festive season of the year, with Christmas and roast beef", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all," + ] + # fmt: on + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + unittest.TestCase().assertListEqual(transcript, EXPECTED_TRANSCRIPT) + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + @require_tf @require_tokenizers class TFWhisperModelIntegrationTests(unittest.TestCase): @@ -634,12 +814,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): return WhisperProcessor.from_pretrained("openai/whisper-base") def _load_datasamples(self, num_samples): - - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - # automatic decoding with librispeech - speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] - - return [x["array"] for x in speech_samples] + return _load_datasamples(num_samples) @slow def test_tiny_logits_librispeech(self): @@ -719,40 +894,11 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): @slow def test_large_logits_librispeech(self): - set_seed(0) - - model = TFWhisperModel.from_pretrained("openai/whisper-large") - - input_speech = self._load_datasamples(1) - - processor = WhisperProcessor.from_pretrained("openai/whisper-large") - processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf") - input_features = processed_inputs.input_features - labels = processed_inputs.labels - - logits = model( - input_features, - decoder_input_ids=labels, - output_hidden_states=False, - output_attentions=False, - use_cache=False, + timeout = os.environ.get("PYTEST_TIMEOUT", 600) + run_test_in_subprocess( + test_case=self, target_func=_test_large_logits_librispeech, inputs=None, timeout=timeout ) - logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0]) - - # fmt: off - EXPECTED_LOGITS = tf.convert_to_tensor( - [ - 2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472, - 1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357, - 1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376, - 1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184 - ] - ) - # fmt: on - - self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) - @slow def test_tiny_en_generation(self): set_seed(0) @@ -816,90 +962,22 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): @slow def test_large_generation(self): - set_seed(0) - processor = WhisperProcessor.from_pretrained("openai/whisper-large") - model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large") - - input_speech = self._load_datasamples(1) - input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features - - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") - generated_ids = model.generate(input_features, do_sample=False, max_length=20) - transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad" - self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + timeout = os.environ.get("PYTEST_TIMEOUT", 600) + run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None, timeout=timeout) @slow def test_large_generation_multilingual(self): - set_seed(0) - processor = WhisperProcessor.from_pretrained("openai/whisper-large") - model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large") - - ds = load_dataset("common_voice", "ja", split="test", streaming=True) - ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) - input_speech = next(iter(ds))["audio"]["array"] - input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features - - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") - generated_ids = model.generate(input_features, do_sample=False, max_length=20) - transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" - self.assertEqual(transcript, EXPECTED_TRANSCRIPT) - - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") - generated_ids = model.generate( - input_features, - do_sample=False, - max_length=20, + timeout = os.environ.get("PYTEST_TIMEOUT", 600) + run_test_in_subprocess( + test_case=self, target_func=_test_large_generation_multilingual, inputs=None, timeout=timeout ) - transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita" - self.assertEqual(transcript, EXPECTED_TRANSCRIPT) - - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") - generated_ids = model.generate(input_features, do_sample=False, max_length=20) - transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" - self.assertEqual(transcript, EXPECTED_TRANSCRIPT) @slow def test_large_batched_generation(self): - set_seed(0) - processor = WhisperProcessor.from_pretrained("openai/whisper-large") - model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large") - - input_speech = self._load_datasamples(4) - input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features - generated_ids = model.generate(input_features, max_length=20) - - # fmt: off - EXPECTED_LOGITS = tf.convert_to_tensor( - [ - [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281], - [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257], - [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256], - [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11] - ] + timeout = os.environ.get("PYTEST_TIMEOUT", 600) + run_test_in_subprocess( + test_case=self, target_func=_test_large_batched_generation, inputs=None, timeout=timeout ) - # fmt: on - - self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS)) - - # fmt: off - EXPECTED_TRANSCRIPT = [ - ' Mr. Quilter is the apostle of the middle classes and we are glad to', - " Nor is Mr. Quilter's manner less interesting than his matter.", - " He tells us that at this festive season of the year, with Christmas and roast beef", - " He has grave doubts whether Sir Frederick Layton's work is really Greek after all," - ] - # fmt: on - - transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) - self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) @slow def test_tiny_en_batched_generation(self):