Fix TF Rag OOM issue (#24122)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-09 15:03:11 +02:00
committed by GitHub
parent f2b918356c
commit 707023d155

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import gc
import json
import os
import shutil
@@ -550,6 +551,11 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
@require_sentencepiece
@require_tokenizers
class TFRagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@cached_property
def token_model(self):
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(