Fix Wav2Vec2 CI OOM (#24190)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-12 11:39:04 +02:00
committed by GitHub
parent 8f093fb799
commit e26c6f03be
2 changed files with 13 additions and 0 deletions

View File

@@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
import gc
import glob import glob
import inspect import inspect
import math import math
@@ -709,6 +710,11 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@require_tf @require_tf
@slow @slow
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
def _load_datasamples(self, num_samples): def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech # automatic decoding with librispeech

View File

@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2 model. """ """ Testing suite for the PyTorch Wav2Vec2 model. """
import gc
import math import math
import multiprocessing import multiprocessing
import os import os
@@ -1374,6 +1375,12 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
@require_soundfile @require_soundfile
@slow @slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase): class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _load_datasamples(self, num_samples): def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech # automatic decoding with librispeech