Fix Pipeline CI OOM issue (#24124)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -507,6 +508,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
|
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
|
||||||
|
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_load_default_pipelines_tf(self):
|
def test_load_default_pipelines_tf(self):
|
||||||
@@ -522,6 +527,9 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
||||||
|
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_load_default_pipelines_pt_table_qa(self):
|
def test_load_default_pipelines_pt_table_qa(self):
|
||||||
@@ -530,6 +538,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
|
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
|
||||||
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
|
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
|
||||||
|
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_tensorflow_probability
|
@require_tensorflow_probability
|
||||||
@@ -539,6 +551,9 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
||||||
self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
|
self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
|
||||||
|
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
|
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -29,7 +30,14 @@ from transformers import (
|
|||||||
TFAutoModelForCausalLM,
|
TFAutoModelForCausalLM,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
is_torch_available,
|
||||||
|
require_tf,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
@@ -39,6 +47,15 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
|||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
class ConversationalPipelineTests(unittest.TestCase):
|
class ConversationalPipelineTests(unittest.TestCase):
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
model_mapping = dict(
|
model_mapping = dict(
|
||||||
list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
|
list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
|
||||||
if MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
if MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||||
|
|||||||
@@ -12,12 +12,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
|
from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
|
||||||
from transformers.pipelines import PipelineException
|
from transformers.pipelines import PipelineException
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
|
is_torch_available,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -33,6 +35,15 @@ class FillMaskPipelineTests(unittest.TestCase):
|
|||||||
model_mapping = MODEL_FOR_MASKED_LM_MAPPING
|
model_mapping = MODEL_FOR_MASKED_LM_MAPPING
|
||||||
tf_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
|
tf_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||||
|
gc.collect()
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2, framework="tf")
|
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2, framework="tf")
|
||||||
|
|||||||
Reference in New Issue
Block a user