From 32adbb26d670c3226d1c874b2b59fc97be7e9445 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Mon, 25 Apr 2022 17:33:56 +0200
Subject: [PATCH] Fix PyTorch RAG tests GPU OOM (#16881)

* add torch.cuda.empty_cache in some PT RAG tests

* torch.cuda.empty_cache in tearDownModule()

* tearDown()

* add gc.collect()

Co-authored-by: ydshieh
---
 tests/rag/test_modeling_rag.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/rag/test_modeling_rag.py b/tests/rag/test_modeling_rag.py
index a7d1288f51..6914318cfa 100644
--- a/tests/rag/test_modeling_rag.py
+++ b/tests/rag/test_modeling_rag.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 
+import gc
 import json
 import os
 import shutil
@@ -195,6 +196,10 @@ class RagTestMixin:
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
 
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def get_retriever(self, config):
         dataset = Dataset.from_dict(
             {
@@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
 @require_tokenizers
 @require_torch_non_multi_gpu
 class RagModelIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        torch.cuda.empty_cache()
+
     @cached_property
     def sequence_model(self):
         return (
@@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase):
 @require_torch
 @require_retrieval
 class RagModelSaveLoadTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def get_rag_config(self):
         question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
         generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")