Fix PyTorch RAG tests GPU OOM (#16881)
* add torch.cuda.empty_cache in some PT RAG tests
* torch.cuda.empty_cache in tearDownModule()
* tearDown()
* add gc.collect()

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import gc
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -195,6 +196,10 @@ class RagTestMixin:
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
shutil.rmtree(self.tmpdirname)
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_retriever(self, config):
|
def get_retriever(self, config):
|
||||||
dataset = Dataset.from_dict(
|
dataset = Dataset.from_dict(
|
||||||
{
|
{
|
||||||
@@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
|||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
@require_torch_non_multi_gpu
|
@require_torch_non_multi_gpu
|
||||||
class RagModelIntegrationTests(unittest.TestCase):
|
class RagModelIntegrationTests(unittest.TestCase):
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sequence_model(self):
|
def sequence_model(self):
|
||||||
return (
|
return (
|
||||||
@@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
@require_retrieval
|
@require_retrieval
|
||||||
class RagModelSaveLoadTests(unittest.TestCase):
|
class RagModelSaveLoadTests(unittest.TestCase):
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_rag_config(self):
|
def get_rag_config(self):
|
||||||
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
||||||
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
|
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
|
||||||
|
|||||||
Reference in New Issue
Block a user