Fix TF Rag OOM issue (#24122)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import gc
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -550,6 +551,11 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
|
|||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TFRagModelIntegrationTests(unittest.TestCase):
|
class TFRagModelIntegrationTests(unittest.TestCase):
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def token_model(self):
|
def token_model(self):
|
||||||
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
|
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
|
||||||
|
|||||||
Reference in New Issue
Block a user