Fix PyTorch RAG tests GPU OOM (#16881)

* add torch.cuda.empty_cache() calls to some PyTorch RAG tests

* call torch.cuda.empty_cache() in tearDownModule()

* do the clean-up in tearDown() instead of tearDownModule()

* add gc.collect()

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-04-25 17:33:56 +02:00
committed by GitHub
parent 3e47d19cfc
commit 32adbb26d6

View File

@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import gc
import json import json
import os import os
import shutil import shutil
@@ -195,6 +196,10 @@ class RagTestMixin:
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
# free as much of the GPU memory occupied by PyTorch as possible
gc.collect()
torch.cuda.empty_cache()
def get_retriever(self, config): def get_retriever(self, config):
dataset = Dataset.from_dict( dataset = Dataset.from_dict(
{ {
@@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
@require_tokenizers @require_tokenizers
@require_torch_non_multi_gpu @require_torch_non_multi_gpu
class RagModelIntegrationTests(unittest.TestCase): class RagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# free as much of the GPU memory occupied by PyTorch as possible
gc.collect()
torch.cuda.empty_cache()
@cached_property @cached_property
def sequence_model(self): def sequence_model(self):
return ( return (
@@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase):
@require_torch @require_torch
@require_retrieval @require_retrieval
class RagModelSaveLoadTests(unittest.TestCase): class RagModelSaveLoadTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# free as much of the GPU memory occupied by PyTorch as possible
gc.collect()
torch.cuda.empty_cache()
def get_rag_config(self): def get_rag_config(self):
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base") question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn") generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")