avoid calling gc.collect and cuda.empty_cache (#34514)
* update * update * update * update * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@@ -29,6 +28,7 @@ from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_
|
||||
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
get_tests_dir,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
@@ -196,8 +196,7 @@ class RagTestMixin:
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
cleanup(torch_device)
|
||||
|
||||
def get_retriever(self, config):
|
||||
dataset = Dataset.from_dict(
|
||||
@@ -684,8 +683,7 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@cached_property
|
||||
def sequence_model(self):
|
||||
@@ -1043,8 +1041,7 @@ class RagModelSaveLoadTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def get_rag_config(self):
|
||||
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
||||
|
||||
Reference in New Issue
Block a user