Fix Pipeline CI OOM issue (#24124)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -507,6 +508,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
|
||||
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_load_default_pipelines_tf(self):
|
||||
@@ -522,6 +527,9 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
|
||||
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_default_pipelines_pt_table_qa(self):
|
||||
@@ -530,6 +538,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
|
||||
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
@require_tensorflow_probability
|
||||
@@ -539,6 +551,9 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
||||
self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
|
||||
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
|
||||
Reference in New Issue
Block a user