From 1d2b57b8a21a6cdfb7706a0c607dc16ee603dbf3 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 1 Jun 2022 16:27:23 +0200 Subject: [PATCH] Fix CTRL tests (#17508) * fix Co-authored-by: ydshieh --- tests/models/ctrl/test_modeling_ctrl.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/ctrl/test_modeling_ctrl.py b/tests/models/ctrl/test_modeling_ctrl.py index 0256a5718b..ad6652f882 100644 --- a/tests/models/ctrl/test_modeling_ctrl.py +++ b/tests/models/ctrl/test_modeling_ctrl.py @@ -13,6 +13,7 @@ # limitations under the License. +import gc import unittest from transformers import CTRLConfig, is_torch_available @@ -181,6 +182,12 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): self.model_tester = CTRLModelTester(self) self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37) + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + torch.cuda.empty_cache() + def test_config(self): self.config_tester.run_common_tests() @@ -201,6 +208,12 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): @require_torch class CTRLModelLanguageGenerationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + torch.cuda.empty_cache() + @slow def test_lm_generate_ctrl(self): model = CTRLLMHeadModel.from_pretrained("ctrl")