From 2e2088f24b60d8817c74c32a0ac6bb1c5d39544d Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 8 Jun 2023 18:21:09 +0200 Subject: [PATCH] Avoid `GPT-2` daily CI job OOM (in TF tests) (#24106) * fix * fix --------- Co-authored-by: ydshieh --- tests/models/gpt2/test_modeling_gpt2.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 0575b74e4c..65542b4954 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -15,6 +15,7 @@ import datetime +import gc import math import unittest @@ -500,6 +501,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin self.model_tester = GPT2ModelTester(self) self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37) + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + torch.cuda.empty_cache() + def test_config(self): self.config_tester.run_common_tests() @@ -683,6 +690,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin @require_torch class GPT2ModelLanguageGenerationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + torch.cuda.empty_cache() + def _test_lm_generate_gpt2_helper( self, gradient_checkpointing=False,