From 847b47c0eed4e6ab904f584fb415e3d3a397867f Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 9 Jun 2023 15:20:59 +0200 Subject: [PATCH] Fix XGLM OOM on CI (#24123) fix Co-authored-by: ydshieh --- tests/models/xglm/test_modeling_tf_xglm.py | 6 ++++++ tests/models/xglm/test_modeling_xglm.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/tests/models/xglm/test_modeling_tf_xglm.py b/tests/models/xglm/test_modeling_tf_xglm.py index e2b8cc2e6c..3582209cc7 100644 --- a/tests/models/xglm/test_modeling_tf_xglm.py +++ b/tests/models/xglm/test_modeling_tf_xglm.py @@ -15,6 +15,7 @@ from __future__ import annotations +import gc import unittest from transformers import XGLMConfig, XGLMTokenizer, is_tf_available @@ -190,6 +191,11 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase @require_tf class TFXGLMModelLanguageGenerationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + @slow def test_lm_generate_xglm(self, verify_outputs=True): model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M") diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py index 5028c30ea9..bbb87abe6d 100644 --- a/tests/models/xglm/test_modeling_xglm.py +++ b/tests/models/xglm/test_modeling_xglm.py @@ -14,6 +14,7 @@ # limitations under the License. import datetime +import gc import math import unittest @@ -349,6 +350,12 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin @require_torch class XGLMModelLanguageGenerationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + torch.cuda.empty_cache() + def _test_lm_generate_xglm_helper( self, gradient_checkpointing=False,