From e26c6f03be35a621d26d79ae59e21ceac3ffa73e Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 12 Jun 2023 11:39:04 +0200 Subject: [PATCH] Fix `Wav2Vec2` CI OOM (#24190) fix Co-authored-by: ydshieh --- tests/models/wav2vec2/test_modeling_tf_wav2vec2.py | 6 ++++++ tests/models/wav2vec2/test_modeling_wav2vec2.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index ef4c38e2a3..391d8e8ce1 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -17,6 +17,7 @@ from __future__ import annotations import copy +import gc import glob import inspect import math @@ -709,6 +710,11 @@ class TFWav2Vec2UtilsTest(unittest.TestCase): @require_tf @slow class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + def _load_datasamples(self, num_samples): ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index cf41dd9a30..87206a4b9b 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -14,6 +14,7 @@ # limitations under the License. """ Testing suite for the PyTorch Wav2Vec2 model. """ +import gc import math import multiprocessing import os @@ -1374,6 +1375,12 @@ class Wav2Vec2UtilsTest(unittest.TestCase): @require_soundfile @slow class Wav2Vec2ModelIntegrationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + torch.cuda.empty_cache() + def _load_datasamples(self, num_samples): ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech