Fix XGLM OOM on CI (#24123)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-09 15:20:59 +02:00
committed by GitHub
parent b8fe259f16
commit 847b47c0ee
2 changed files with 13 additions and 0 deletions

View File

@@ -15,6 +15,7 @@
from __future__ import annotations from __future__ import annotations
import gc
import unittest import unittest
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
@@ -190,6 +191,11 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
@require_tf @require_tf
class TFXGLMModelLanguageGenerationTest(unittest.TestCase): class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@slow @slow
def test_lm_generate_xglm(self, verify_outputs=True): def test_lm_generate_xglm(self, verify_outputs=True):
model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M") model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")

View File

@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
import gc
import math import math
import unittest import unittest
@@ -349,6 +350,12 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
@require_torch @require_torch
class XGLMModelLanguageGenerationTest(unittest.TestCase): class XGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _test_lm_generate_xglm_helper( def _test_lm_generate_xglm_helper(
self, self,
gradient_checkpointing=False, gradient_checkpointing=False,