Fix Pipeline CI OOM issue (#24124)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-09 16:49:02 +02:00
committed by GitHub
parent a7501f6fc6
commit d0d1632958
3 changed files with 44 additions and 1 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import os
import sys
@@ -507,6 +508,10 @@ class PipelineUtilsTest(unittest.TestCase):
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
@slow
@require_tf
def test_load_default_pipelines_tf(self):
@@ -522,6 +527,9 @@ class PipelineUtilsTest(unittest.TestCase):
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@slow
@require_torch
def test_load_default_pipelines_pt_table_qa(self):
@@ -530,6 +538,10 @@ class PipelineUtilsTest(unittest.TestCase):
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
@slow
@require_tf
@require_tensorflow_probability
@@ -539,6 +551,9 @@ class PipelineUtilsTest(unittest.TestCase):
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
from transformers.pipelines import SUPPORTED_TASKS, pipeline