From 707023d15590ab2776df728d3a499a4f05d0a726 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 9 Jun 2023 15:03:11 +0200 Subject: [PATCH] Fix TF Rag OOM issue (#24122) fix Co-authored-by: ydshieh --- tests/models/rag/test_modeling_tf_rag.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/rag/test_modeling_tf_rag.py b/tests/models/rag/test_modeling_tf_rag.py index b4720f7c7f..a7edf6e0f1 100644 --- a/tests/models/rag/test_modeling_tf_rag.py +++ b/tests/models/rag/test_modeling_tf_rag.py @@ -1,5 +1,6 @@ from __future__ import annotations +import gc import json import os import shutil @@ -550,6 +551,11 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase): @require_sentencepiece @require_tokenizers class TFRagModelIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + @cached_property def token_model(self): return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(